Pandora
Pandora source code navigator
Loading...
Searching...
No Matches
LArAdaBoostDecisionTree.cc
Go to the documentation of this file.
1
9#include "Helpers/XmlHelper.h"
10
12
13using namespace pandora;
14
15namespace lar_content
16{
17
18AdaBoostDecisionTree::AdaBoostDecisionTree() : m_pStrongClassifier(nullptr)
19{
20}
21
22//------------------------------------------------------------------------------------------------------------------------------------------
23
28
29//------------------------------------------------------------------------------------------------------------------------------------------
30
32{
33 if (this != &rhs)
35
36 return *this;
37}
38
39//------------------------------------------------------------------------------------------------------------------------------------------
40
45
46//------------------------------------------------------------------------------------------------------------------------------------------
47
48StatusCode AdaBoostDecisionTree::Initialize(const std::string &bdtXmlFileName, const std::string &bdtName)
49{
51 {
52 std::cout << "AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
53 return STATUS_CODE_ALREADY_INITIALIZED;
54 }
55
56 TiXmlDocument xmlDocument(bdtXmlFileName);
57
58 if (!xmlDocument.LoadFile())
59 {
60 std::cout << "AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
61 return STATUS_CODE_INVALID_PARAMETER;
62 }
63
64 const TiXmlHandle xmlDocumentHandle(&xmlDocument);
65 TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
66
67 while (pContainerXmlNode)
68 {
69 if (pContainerXmlNode->ValueStr() != "AdaBoostDecisionTree")
70 return STATUS_CODE_FAILURE;
71
72 const TiXmlHandle currentHandle(pContainerXmlNode);
73
74 std::string currentName;
75 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
76
77 if (currentName.empty() || (currentName.size() > 1000))
78 {
79 std::cout << "AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
80 return STATUS_CODE_INVALID_PARAMETER;
81 }
82
83 if (currentName == bdtName)
84 break;
85
86 pContainerXmlNode = pContainerXmlNode->NextSibling();
87 }
88
89 if (!pContainerXmlNode)
90 {
91 std::cout << "AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
92 return STATUS_CODE_NOT_FOUND;
93 }
94
95 const TiXmlHandle xmlHandle(pContainerXmlNode);
96
97 try
98 {
99 m_pStrongClassifier = new StrongClassifier(&xmlHandle);
100 }
101 catch (StatusCodeException &statusCodeException)
102 {
103 delete m_pStrongClassifier;
104
105 if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
106 std::cout << "AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
107
108 if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
109 std::cout << "AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
110
111 return statusCodeException.GetStatusCode();
112 }
113
114 return STATUS_CODE_SUCCESS;
115}
116
117//------------------------------------------------------------------------------------------------------------------------------------------
118
120{
121 return ((this->CalculateScore(features) > 0.) ? true : false);
122}
123
124//------------------------------------------------------------------------------------------------------------------------------------------
125
127{
128 return this->CalculateScore(features);
129}
130
131//------------------------------------------------------------------------------------------------------------------------------------------
132
134{
135 // ATTN: BDT score, once normalised by total weight, is confined to the range -1 to +1. This linear mapping places the score in the
136 // range 0 to 1 so that it may be interpreted as a probability.
137 return (this->CalculateScore(features) + 1.) * 0.5;
138}
139
140//------------------------------------------------------------------------------------------------------------------------------------------
141
143{
145 {
146 std::cout << "AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
147 throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
148 }
149
150 try
151 {
152 // TODO: Add consistency check for number of features, bearing in mind not all features in a bdt may be used
153 return m_pStrongClassifier->Predict(features);
154 }
155 catch (StatusCodeException &statusCodeException)
156 {
157 if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
158 {
159 std::cout << "AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
160 }
161 else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
162 {
163 std::cout << "AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier."
164 << std::endl;
165 }
166 else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
167 {
168 std::cout << "AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
169 }
170 else
171 {
172 std::cout << "AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
173 }
174
175 throw statusCodeException;
176 }
177}
178
179//------------------------------------------------------------------------------------------------------------------------------------------
180//------------------------------------------------------------------------------------------------------------------------------------------
181
183 m_nodeId(0),
184 m_parentNodeId(0),
185 m_leftChildNodeId(0),
186 m_rightChildNodeId(0),
187 m_isLeaf(false),
188 m_threshold(0.),
189 m_variableId(0),
190 m_outcome(false)
191{
192 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "NodeId", m_nodeId));
193 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "ParentNodeId", m_parentNodeId));
194
195 const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "LeftChildNodeId", m_leftChildNodeId));
196 const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "RightChildNodeId", m_rightChildNodeId));
197 const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Threshold", m_threshold));
198 const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "VariableId", m_variableId));
199 const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Outcome", m_outcome));
200
201 if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202 STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
203 {
204 m_isLeaf = false;
205 m_outcome = false;
206 }
207 else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
208 {
209 m_isLeaf = true;
210 m_leftChildNodeId = std::numeric_limits<int>::max();
211 m_rightChildNodeId = std::numeric_limits<int>::max();
212 m_threshold = std::numeric_limits<double>::max();
213 m_variableId = std::numeric_limits<int>::max();
214 }
215 else
216 {
217 throw StatusCodeException(STATUS_CODE_FAILURE);
218 }
219}
220
221//------------------------------------------------------------------------------------------------------------------------------------------
222
224 m_nodeId(rhs.m_nodeId),
225 m_parentNodeId(rhs.m_parentNodeId),
226 m_leftChildNodeId(rhs.m_leftChildNodeId),
227 m_rightChildNodeId(rhs.m_rightChildNodeId),
228 m_isLeaf(rhs.m_isLeaf),
229 m_threshold(rhs.m_threshold),
230 m_variableId(rhs.m_variableId),
231 m_outcome(rhs.m_outcome)
232{
233}
234
235//------------------------------------------------------------------------------------------------------------------------------------------
236
238{
239 if (this != &rhs)
240 {
241 m_nodeId = rhs.m_nodeId;
242 m_parentNodeId = rhs.m_parentNodeId;
243 m_leftChildNodeId = rhs.m_leftChildNodeId;
244 m_rightChildNodeId = rhs.m_rightChildNodeId;
245 m_isLeaf = rhs.m_isLeaf;
246 m_threshold = rhs.m_threshold;
247 m_variableId = rhs.m_variableId;
248 m_outcome = rhs.m_outcome;
249 }
250
251 return *this;
252}
253
254//------------------------------------------------------------------------------------------------------------------------------------------
255
259
260//------------------------------------------------------------------------------------------------------------------------------------------
261//------------------------------------------------------------------------------------------------------------------------------------------
262
263AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const TiXmlHandle *const pXmlHandle) : m_weight(0.), m_treeId(0)
264{
265 for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL;
266 pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
267 {
268 if ("TreeIndex" == pHeadTiXmlElement->ValueStr())
269 {
270 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeIndex", m_treeId));
271 }
272 else if ("TreeWeight" == pHeadTiXmlElement->ValueStr())
273 {
274 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeWeight", m_weight));
275 }
276 else if ("Node" == pHeadTiXmlElement->ValueStr())
277 {
278 const TiXmlHandle nodeHandle(pHeadTiXmlElement);
279 const Node *pNode = new Node(&nodeHandle);
280 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
281 }
282 }
283}
284
285//------------------------------------------------------------------------------------------------------------------------------------------
286
287AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const WeakClassifier &rhs) : m_weight(rhs.m_weight), m_treeId(rhs.m_treeId)
288{
289 for (const auto &mapEntry : rhs.m_idToNodeMap)
290 {
291 const Node *pNode = new Node(*(mapEntry.second));
292 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
293 }
294}
295
296//------------------------------------------------------------------------------------------------------------------------------------------
297
299{
300 if (this != &rhs)
301 {
302 for (const auto &mapEntry : rhs.m_idToNodeMap)
303 {
304 const Node *pNode = new Node(*(mapEntry.second));
305 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
306 }
307
308 m_weight = rhs.m_weight;
309 m_treeId = rhs.m_treeId;
310 }
311
312 return *this;
313}
314
315//------------------------------------------------------------------------------------------------------------------------------------------
316
318{
319 for (const auto &mapEntry : m_idToNodeMap)
320 delete mapEntry.second;
321}
322
323//------------------------------------------------------------------------------------------------------------------------------------------
324
326{
327 return this->EvaluateNode(0, features);
328}
329
330//------------------------------------------------------------------------------------------------------------------------------------------
331
333{
334 const Node *pActiveNode(nullptr);
335
336 if (m_idToNodeMap.find(nodeId) != m_idToNodeMap.end())
337 {
338 pActiveNode = m_idToNodeMap.at(nodeId);
339 }
340 else
341 {
342 throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
343 }
344
345 if (pActiveNode->IsLeaf())
346 return pActiveNode->GetOutcome();
347
348 if (static_cast<int>(features.size()) <= pActiveNode->GetVariableId())
349 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
350
351 if (features.at(pActiveNode->GetVariableId()).Get() <= pActiveNode->GetThreshold())
352 {
353 return this->EvaluateNode(pActiveNode->GetLeftChildNodeId(), features);
354 }
355 else
356 {
357 return this->EvaluateNode(pActiveNode->GetRightChildNodeId(), features);
358 }
359}
360
361//------------------------------------------------------------------------------------------------------------------------------------------
362//------------------------------------------------------------------------------------------------------------------------------------------
363
365{
366 TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
367
368 while (pCurrentXmlElement)
369 {
370 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
371 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
372
373 pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
374 }
375}
376
377//------------------------------------------------------------------------------------------------------------------------------------------
378
380{
381 for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
382 m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
383}
384
385//------------------------------------------------------------------------------------------------------------------------------------------
386
388{
389 if (this != &rhs)
390 {
391 for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
392 m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
393 }
394
395 return *this;
396}
397
398//------------------------------------------------------------------------------------------------------------------------------------------
399
401{
402 for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
403 delete pWeakClassifier;
404}
405
406//------------------------------------------------------------------------------------------------------------------------------------------
407
409{
410 double score(0.), weights(0.);
411
412 for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
413 {
414 weights += pWeakClassifier->GetWeight();
415
416 if (pWeakClassifier->Predict(features))
417 {
418 score += pWeakClassifier->GetWeight();
419 }
420 else
421 {
422 score -= pWeakClassifier->GetWeight();
423 }
424 }
425
426 if (weights > std::numeric_limits<double>::epsilon())
427 {
428 score /= weights;
429 }
430 else
431 {
432 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
433 }
434
435 return score;
436}
437
438//------------------------------------------------------------------------------------------------------------------------------------------
439
441{
442 const std::string componentName(pCurrentXmlElement->ValueStr());
443 TiXmlHandle currentHandle(pCurrentXmlElement);
444
445 if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
446 return STATUS_CODE_SUCCESS;
447
448 if (std::string("DecisionTree") == componentName)
449 {
450 m_weakClassifiers.emplace_back(new WeakClassifier(&currentHandle));
451 return STATUS_CODE_SUCCESS;
452 }
453
454 return STATUS_CODE_INVALID_PARAMETER;
455}
456
457} // namespace lar_content
Header file for the lar adaptive boosted decision tree class.
#define PANDORA_THROW_RESULT_IF(StatusCode1, Operator, Command)
Definition StatusCodes.h:43
Header file for the xml helper class.
Node class used for representing a decision tree.
double GetThreshold() const
Return node threshold.
int m_variableId
Variable cut on for decision if decision node.
double m_threshold
Threshold used for decision if decision node.
int GetVariableId() const
Return cut variable.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetLeftChildNodeId() const
Return left child node id.
bool IsLeaf() const
Return is the node a leaf.
int GetRightChildNodeId() const
Return right child node id.
Node & operator=(const Node &rhs)
Assignment operator.
StrongClassifier class used in application of adaptive boost decision tree.
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
WeakClassifier class containing a decision tree and a weight.
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
MvaTypes::MvaFeatureVector MvaFeatureVector
StatusCodeException class.
StatusCode GetStatusCode() const
Get status code.
bool LoadFile(TiXmlEncoding encoding=TIXML_DEFAULT_ENCODING)
Definition tinyxml.cc:957
TiXmlHandle FirstChildElement() const
Return a handle to the first child element.
Definition tinyxml.cc:1659
TiXmlElement * Element() const
Definition tinyxml.h:1710
TiXmlHandle FirstChild() const
Return a handle to the first child node.
Definition tinyxml.cc:1635
TiXmlElement * ToElement() const
Definition tinyxml.h:1695
const std::string & ValueStr() const
Definition tinyxml.h:501
const TiXmlElement * NextSiblingElement() const
Definition tinyxml.cc:485
const TiXmlNode * NextSibling(const std::string &_value) const
STL std::string form.
Definition tinyxml.h:633
static StatusCode ReadValue(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, T &t)
Read a value from an xml element.
Definition XmlHelper.h:136
StatusCode
The StatusCode enum.