Pandora
Pandora source code navigator
Loading...
Searching...
No Matches
LArSupportVectorMachine.cc
Go to the documentation of this file.
1
9#include "Helpers/XmlHelper.h"
10
12
13using namespace pandora;
14
15namespace lar_content
16{
17
19 m_isInitialized(false),
20 m_enableProbability(false),
21 m_probAParameter(0.),
22 m_probBParameter(0.),
23 m_standardizeFeatures(true),
24 m_nFeatures(0),
25 m_bias(0.),
26 m_scaleFactor(1.),
27 m_kernelType(QUADRATIC),
28 m_kernelFunction(QuadraticKernel),
29 m_kernelMap{{LINEAR, LinearKernel}, {QUADRATIC, QuadraticKernel}, {CUBIC, CubicKernel}, {GAUSSIAN_RBF, GaussianRbfKernel}}
30{
31}
32
33//------------------------------------------------------------------------------------------------------------------------------------------
34
35StatusCode SupportVectorMachine::Initialize(const std::string &parameterLocation, const std::string &svmName)
36{
37 if (m_isInitialized)
38 {
39 std::cout << "SupportVectorMachine: svm was already initialized" << std::endl;
40 return STATUS_CODE_ALREADY_INITIALIZED;
41 }
42
43 this->ReadXmlFile(parameterLocation, svmName);
44
45 // Check the sizes of sigma and scale factor if they are to be used as divisors
46 if (m_standardizeFeatures)
47 {
48 for (const FeatureInfo &featureInfo : m_featureInfoList)
49 {
50 if (featureInfo.m_sigmaValue < std::numeric_limits<double>::epsilon())
51 {
52 std::cout << "SupportVectorMachine: could not standardize parameters because sigma value was too small" << std::endl;
53 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
54 }
55 }
56 }
57
58 // Check the number of features is consistent.
59 m_nFeatures = m_featureInfoList.size();
60
61 for (const SupportVectorInfo &svInfo : m_svInfoList)
62 {
63 if (svInfo.m_supportVector.size() != m_nFeatures)
64 {
65 std::cout << "SupportVectorMachine: the number of features in the xml file was inconsistent" << std::endl;
66 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
67 }
68 }
69
70 // There's the possibility of a user-defined kernel that doesn't use this as a divisor but let's be safe
71 if (m_scaleFactor < std::numeric_limits<double>::epsilon())
72 {
73 std::cout << "SupportVectorMachine: could not evaluate kernel because scale factor was too small" << std::endl;
74 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
75 }
76
77 m_isInitialized = true;
78 return STATUS_CODE_SUCCESS;
79}
80
81//------------------------------------------------------------------------------------------------------------------------------------------
82
83void SupportVectorMachine::ReadXmlFile(const std::string &svmFileName, const std::string &svmName)
84{
85 TiXmlDocument xmlDocument(svmFileName);
86
87 if (!xmlDocument.LoadFile())
88 {
89 std::cout << "SupportVectorMachine::Initialize - Invalid xml file." << std::endl;
90 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
91 }
92
93 const TiXmlHandle xmlDocumentHandle(&xmlDocument);
94 TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
95
96 // Try to find the svm container with the required name
97 while (pContainerXmlNode)
98 {
99 if (pContainerXmlNode->ValueStr() != "SupportVectorMachine")
100 throw StatusCodeException(STATUS_CODE_FAILURE);
101
102 const TiXmlHandle currentHandle(pContainerXmlNode);
103
104 std::string currentName;
105 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
106
107 if (currentName.empty() || (currentName.size() > 1000))
108 {
109 std::cout << "SupportVectorMachine::Initialize - Implausible svm name extracted from xml." << std::endl;
110 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
111 }
112
113 if (currentName == svmName)
114 break;
115
116 pContainerXmlNode = pContainerXmlNode->NextSibling();
117 }
118
119 if (!pContainerXmlNode)
120 {
121 std::cout << "SupportVectorMachine: Could not find an svm by the name " << svmName << std::endl;
122 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
123 }
124
125 // Read the components of this svm container
126 TiXmlHandle localHandle(pContainerXmlNode);
127 TiXmlElement *pCurrentXmlElement = localHandle.FirstChild().Element();
128
129 while (pCurrentXmlElement)
130 {
131 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
132 {
133 std::cout << "SupportVectorMachine: Unknown component in xml file" << std::endl;
134 throw StatusCodeException(STATUS_CODE_FAILURE);
135 }
136
137 pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
138 }
139}
140
141//------------------------------------------------------------------------------------------------------------------------------------------
142
143StatusCode SupportVectorMachine::ReadComponent(TiXmlElement *pCurrentXmlElement)
144{
145 const std::string componentName(pCurrentXmlElement->ValueStr());
146 const TiXmlHandle currentHandle(pCurrentXmlElement);
147
148 if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
149 return STATUS_CODE_SUCCESS;
150
151 if (std::string("Machine") == componentName)
152 return this->ReadMachine(currentHandle);
153
154 if (std::string("Features") == componentName)
155 return this->ReadFeatures(currentHandle);
156
157 if (std::string("SupportVector") == componentName)
158 return this->ReadSupportVector(currentHandle);
159
160 return STATUS_CODE_INVALID_PARAMETER;
161}
162
163//------------------------------------------------------------------------------------------------------------------------------------------
164
165StatusCode SupportVectorMachine::ReadMachine(const TiXmlHandle &currentHandle)
166{
167 int kernelType(0);
168 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "KernelType", kernelType));
169
170 double bias(0.);
171 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "Bias", bias));
172
173 double scaleFactor(0.);
174 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "ScaleFactor", scaleFactor));
175
176 bool standardize(true);
177 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "Standardize", standardize));
178
179 bool enableProbability(false);
181 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "EnableProbability", enableProbability));
182
183 double probAParameter(0.);
184 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "ProbAParameter", probAParameter));
185
186 double probBParameter(0.);
187 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "ProbBParameter", probBParameter));
188
189 m_kernelType = static_cast<KernelType>(kernelType);
190 m_bias = bias;
191 m_scaleFactor = scaleFactor;
192 m_enableProbability = enableProbability;
193 m_probAParameter = probAParameter;
194 m_probBParameter = probBParameter;
195
196 if (kernelType != USER_DEFINED) // if user-defined, leave it so it alone can be set before/after initialization
197 m_kernelFunction = m_kernelMap.at(m_kernelType);
198
199 return STATUS_CODE_SUCCESS;
200}
201
202//------------------------------------------------------------------------------------------------------------------------------------------
203
204StatusCode SupportVectorMachine::ReadFeatures(const TiXmlHandle &currentHandle)
205{
206 std::vector<double> muValues;
207 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(currentHandle, "MuValues", muValues));
208
209 std::vector<double> sigmaValues;
211 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(currentHandle, "SigmaValues", sigmaValues));
212
213 if (muValues.size() != sigmaValues.size())
214 {
215 std::cout << "SupportVectorMachine: could not add feature info because the size of mu (" << muValues.size()
216 << ") did not match "
217 "the size of sigma ("
218 << sigmaValues.size() << ")" << std::endl;
219 return STATUS_CODE_INVALID_PARAMETER;
220 }
221
222 m_featureInfoList.reserve(muValues.size());
223
224 for (std::size_t i = 0; i < muValues.size(); ++i)
225 m_featureInfoList.emplace_back(muValues.at(i), sigmaValues.at(i));
226
227 return STATUS_CODE_SUCCESS;
228}
229
230//------------------------------------------------------------------------------------------------------------------------------------------
231
232StatusCode SupportVectorMachine::ReadSupportVector(const TiXmlHandle &currentHandle)
233{
234 double yAlpha(0.0);
235 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "AlphaY", yAlpha));
236
237 std::vector<double> values;
238 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(currentHandle, "Values", values));
239
240 LArMvaHelper::MvaFeatureVector valuesFeatureVector;
241 for (const double &value : values)
242 valuesFeatureVector.emplace_back(value);
243
244 m_svInfoList.emplace_back(yAlpha, valuesFeatureVector);
245 return STATUS_CODE_SUCCESS;
246}
247
248//------------------------------------------------------------------------------------------------------------------------------------------
249
250double SupportVectorMachine::CalculateClassificationScoreImpl(const LArMvaHelper::MvaFeatureVector &features) const
251{
252 if (!m_isInitialized)
253 {
254 std::cout << "SupportVectorMachine: could not perform classification because the svm was uninitialized" << std::endl;
255 throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
256 }
257
258 if (m_svInfoList.empty())
259 {
260 std::cout << "SupportVectorMachine: could not perform classification because the initialized svm had no support vectors in the model"
261 << std::endl;
262 throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
263 }
264
265 LArMvaHelper::MvaFeatureVector standardizedFeatures;
266 standardizedFeatures.reserve(m_nFeatures);
267
268 if (m_standardizeFeatures)
269 {
270 for (std::size_t i = 0; i < m_nFeatures; ++i)
271 standardizedFeatures.push_back(m_featureInfoList.at(i).StandardizeParameter(features.at(i).Get()));
272 }
273
274 double classScore(0.);
275 for (const SupportVectorInfo &supportVectorInfo : m_svInfoList)
276 {
277 classScore += supportVectorInfo.m_yAlpha * m_kernelFunction(supportVectorInfo.m_supportVector,
278 (m_standardizeFeatures ? standardizedFeatures : features), m_scaleFactor);
279 }
280
281 return classScore + m_bias;
282}
283
284} // namespace lar_content
Header file for the lar support vector machine class.
#define PANDORA_THROW_RESULT_IF(StatusCode1, Operator, Command)
Definition StatusCodes.h:43
#define PANDORA_RETURN_RESULT_IF_AND_IF(StatusCode1, StatusCode2, Operator, Command)
Definition StatusCodes.h:31
Header file for the xml helper class.
MvaTypes::MvaFeatureVector MvaFeatureVector
static double LinearKernel(const LArMvaHelper::MvaFeatureVector &supportVector, const LArMvaHelper::MvaFeatureVector &features, const double scaleFactor=1.)
A linear kernel.
StatusCodeException class.
bool LoadFile(TiXmlEncoding encoding=TIXML_DEFAULT_ENCODING)
Definition tinyxml.cc:957
TiXmlElement * Element() const
Definition tinyxml.h:1710
TiXmlHandle FirstChild() const
Return a handle to the first child node.
Definition tinyxml.cc:1635
const std::string & ValueStr() const
Definition tinyxml.h:501
const TiXmlElement * NextSiblingElement() const
Definition tinyxml.cc:485
const TiXmlNode * NextSibling(const std::string &_value) const
STL std::string form.
Definition tinyxml.h:633
static StatusCode ReadVectorOfValues(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, std::vector< T > &vector)
Read a vector of values from a (space separated) list in an xml element.
Definition XmlHelper.h:229
static StatusCode ReadValue(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, T &t)
Read a value from an xml element.
Definition XmlHelper.h:136
StatusCode
The StatusCode enum.