52 std::cout <<
"AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
53 return STATUS_CODE_ALREADY_INITIALIZED;
60 std::cout <<
"AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
61 return STATUS_CODE_INVALID_PARAMETER;
67 while (pContainerXmlNode)
69 if (pContainerXmlNode->
ValueStr() !=
"AdaBoostDecisionTree")
70 return STATUS_CODE_FAILURE;
74 std::string currentName;
77 if (currentName.empty() || (currentName.size() > 1000))
79 std::cout <<
"AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
80 return STATUS_CODE_INVALID_PARAMETER;
83 if (currentName == bdtName)
86 pContainerXmlNode = pContainerXmlNode->
NextSibling();
89 if (!pContainerXmlNode)
91 std::cout <<
"AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
92 return STATUS_CODE_NOT_FOUND;
105 if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.
GetStatusCode())
106 std::cout <<
"AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
108 if (STATUS_CODE_FAILURE == statusCodeException.
GetStatusCode())
109 std::cout <<
"AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
114 return STATUS_CODE_SUCCESS;
146 std::cout <<
"AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
157 if (STATUS_CODE_NOT_FOUND == statusCodeException.
GetStatusCode())
159 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
161 else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.
GetStatusCode())
163 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier."
166 else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.
GetStatusCode())
168 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
172 std::cout <<
"AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
175 throw statusCodeException;
185 m_leftChildNodeId(0),
186 m_rightChildNodeId(0),
201 if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202 STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
207 else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
224 m_nodeId(rhs.m_nodeId),
225 m_parentNodeId(rhs.m_parentNodeId),
226 m_leftChildNodeId(rhs.m_leftChildNodeId),
227 m_rightChildNodeId(rhs.m_rightChildNodeId),
228 m_isLeaf(rhs.m_isLeaf),
229 m_threshold(rhs.m_threshold),
230 m_variableId(rhs.m_variableId),
231 m_outcome(rhs.m_outcome)
266 pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
268 if (
"TreeIndex" == pHeadTiXmlElement->ValueStr())
272 else if (
"TreeWeight" == pHeadTiXmlElement->ValueStr())
276 else if (
"Node" == pHeadTiXmlElement->ValueStr())
279 const Node *pNode =
new Node(&nodeHandle);
291 const Node *pNode =
new Node(*(mapEntry.second));
304 const Node *pNode =
new Node(*(mapEntry.second));
305 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->
GetNodeId(), pNode));
319 for (
const auto &mapEntry : m_idToNodeMap)
320 delete mapEntry.second;
327 return this->EvaluateNode(0, features);
334 const Node *pActiveNode(
nullptr);
336 if (m_idToNodeMap.find(nodeId) != m_idToNodeMap.end())
338 pActiveNode = m_idToNodeMap.at(nodeId);
345 if (pActiveNode->
IsLeaf())
348 if (
static_cast<int>(features.size()) <= pActiveNode->
GetVariableId())
368 while (pCurrentXmlElement)
370 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
382 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
392 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
402 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
403 delete pWeakClassifier;
410 double score(0.), weights(0.);
412 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
414 weights += pWeakClassifier->GetWeight();
416 if (pWeakClassifier->Predict(features))
418 score += pWeakClassifier->GetWeight();
422 score -= pWeakClassifier->GetWeight();
426 if (weights > std::numeric_limits<double>::epsilon())
442 const std::string componentName(pCurrentXmlElement->
ValueStr());
445 if ((std::string(
"Name") == componentName) || (std::string(
"Timestamp") == componentName))
446 return STATUS_CODE_SUCCESS;
448 if (std::string(
"DecisionTree") == componentName)
450 m_weakClassifiers.emplace_back(
new WeakClassifier(¤tHandle));
451 return STATUS_CODE_SUCCESS;
454 return STATUS_CODE_INVALID_PARAMETER;
Header file for the LAr adaptive boosted decision tree class.
#define PANDORA_THROW_RESULT_IF(StatusCode1, Operator, Command)
Header file for the xml helper class.
Node class used for representing a decision tree.
double GetThreshold() const
Return node threshold.
int m_variableId
Index of the feature variable cut on, if this is a decision node.
double m_threshold
Threshold used for decision if decision node.
int GetVariableId() const
Return cut variable.
int GetNodeId() const
Return node id.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetLeftChildNodeId() const
Return left child node id.
int m_rightChildNodeId
Right child node id.
bool IsLeaf() const
Return is the node a leaf.
int m_parentNodeId
Parent node id.
int GetRightChildNodeId() const
Return right child node id.
bool GetOutcome() const
Return outcome.
Node & operator=(const Node &rhs)
Assignment operator.
bool m_isLeaf
Is node a leaf.
int m_leftChildNodeId
Left child node id.
bool m_outcome
Outcome if leaf node.
StrongClassifier class used in application of adaptive boost decision tree.
~StrongClassifier()
Destructor.
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read an xml element and, if it describes a weak classifier, add it to the member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
WeakClassifier class containing a decision tree and a weight.
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double m_weight
Boost weight.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
IdToNodeMap m_idToNodeMap
Decision tree nodes.
int m_treeId
Decision tree id.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
~WeakClassifier()
Destructor.
AdaBoostDecisionTree class.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
~AdaBoostDecisionTree()
Destructor.
AdaBoostDecisionTree()
Constructor.
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
MvaTypes::MvaFeatureVector MvaFeatureVector
StatusCodeException class.
StatusCode GetStatusCode() const
Get status code.
bool LoadFile(TiXmlEncoding encoding=TIXML_DEFAULT_ENCODING)
TiXmlHandle FirstChildElement() const
Return a handle to the first child element.
TiXmlElement * Element() const
TiXmlHandle FirstChild() const
Return a handle to the first child node.
TiXmlElement * ToElement() const
const std::string & ValueStr() const
const TiXmlElement * NextSiblingElement() const
const TiXmlNode * NextSibling(const std::string &_value) const
STL std::string form.
static StatusCode ReadValue(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, T &t)
Read a value from an xml element.
StatusCode
The StatusCode enum.