19 m_isInitialized(false),
20 m_enableProbability(false),
23 m_standardizeFeatures(true),
27 m_kernelType(QUADRATIC),
28 m_kernelFunction(QuadraticKernel),
29 m_kernelMap{{
LINEAR,
LinearKernel}, {QUADRATIC, QuadraticKernel}, {CUBIC, CubicKernel}, {GAUSSIAN_RBF, GaussianRbfKernel}}
35StatusCode SupportVectorMachine::Initialize(
const std::string ¶meterLocation,
const std::string &svmName)
39 std::cout <<
"SupportVectorMachine: svm was already initialized" << std::endl;
40 return STATUS_CODE_ALREADY_INITIALIZED;
43 this->ReadXmlFile(parameterLocation, svmName);
46 if (m_standardizeFeatures)
48 for (
const FeatureInfo &featureInfo : m_featureInfoList)
50 if (featureInfo.m_sigmaValue < std::numeric_limits<double>::epsilon())
52 std::cout <<
"SupportVectorMachine: could not standardize parameters because sigma value was too small" << std::endl;
59 m_nFeatures = m_featureInfoList.size();
63 if (svInfo.m_supportVector.size() != m_nFeatures)
65 std::cout <<
"SupportVectorMachine: the number of features in the xml file was inconsistent" << std::endl;
71 if (m_scaleFactor < std::numeric_limits<double>::epsilon())
73 std::cout <<
"SupportVectorMachine: could not evaluate kernel because scale factor was too small" << std::endl;
77 m_isInitialized =
true;
78 return STATUS_CODE_SUCCESS;
83void SupportVectorMachine::ReadXmlFile(
const std::string &svmFileName,
const std::string &svmName)
89 std::cout <<
"SupportVectorMachine::Initialize - Invalid xml file." << std::endl;
97 while (pContainerXmlNode)
99 if (pContainerXmlNode->
ValueStr() !=
"SupportVectorMachine")
102 const TiXmlHandle currentHandle(pContainerXmlNode);
104 std::string currentName;
107 if (currentName.empty() || (currentName.size() > 1000))
109 std::cout <<
"SupportVectorMachine::Initialize - Implausible svm name extracted from xml." << std::endl;
113 if (currentName == svmName)
116 pContainerXmlNode = pContainerXmlNode->
NextSibling();
119 if (!pContainerXmlNode)
121 std::cout <<
"SupportVectorMachine: Could not find an svm by the name " << svmName << std::endl;
129 while (pCurrentXmlElement)
131 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
133 std::cout <<
"SupportVectorMachine: Unknown component in xml file" << std::endl;
145 const std::string componentName(pCurrentXmlElement->
ValueStr());
146 const TiXmlHandle currentHandle(pCurrentXmlElement);
148 if ((std::string(
"Name") == componentName) || (std::string(
"Timestamp") == componentName))
149 return STATUS_CODE_SUCCESS;
151 if (std::string(
"Machine") == componentName)
152 return this->ReadMachine(currentHandle);
154 if (std::string(
"Features") == componentName)
155 return this->ReadFeatures(currentHandle);
157 if (std::string(
"SupportVector") == componentName)
158 return this->ReadSupportVector(currentHandle);
160 return STATUS_CODE_INVALID_PARAMETER;
173 double scaleFactor(0.);
176 bool standardize(
true);
179 bool enableProbability(
false);
181 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=,
XmlHelper::ReadValue(currentHandle,
"EnableProbability", enableProbability));
183 double probAParameter(0.);
186 double probBParameter(0.);
189 m_kernelType =
static_cast<KernelType>(kernelType);
191 m_scaleFactor = scaleFactor;
192 m_enableProbability = enableProbability;
193 m_probAParameter = probAParameter;
194 m_probBParameter = probBParameter;
196 if (kernelType != USER_DEFINED)
197 m_kernelFunction = m_kernelMap.at(m_kernelType);
199 return STATUS_CODE_SUCCESS;
206 std::vector<double> muValues;
209 std::vector<double> sigmaValues;
213 if (muValues.size() != sigmaValues.size())
215 std::cout <<
"SupportVectorMachine: could not add feature info because the size of mu (" << muValues.size()
216 <<
") did not match "
217 "the size of sigma ("
218 << sigmaValues.size() <<
")" << std::endl;
219 return STATUS_CODE_INVALID_PARAMETER;
222 m_featureInfoList.reserve(muValues.size());
224 for (std::size_t i = 0; i < muValues.size(); ++i)
225 m_featureInfoList.emplace_back(muValues.at(i), sigmaValues.at(i));
227 return STATUS_CODE_SUCCESS;
237 std::vector<double> values;
241 for (
const double &value : values)
242 valuesFeatureVector.emplace_back(value);
244 m_svInfoList.emplace_back(yAlpha, valuesFeatureVector);
245 return STATUS_CODE_SUCCESS;
252 if (!m_isInitialized)
254 std::cout <<
"SupportVectorMachine: could not perform classification because the svm was uninitialized" << std::endl;
258 if (m_svInfoList.empty())
260 std::cout <<
"SupportVectorMachine: could not perform classification because the initialized svm had no support vectors in the model"
266 standardizedFeatures.reserve(m_nFeatures);
268 if (m_standardizeFeatures)
270 for (std::size_t i = 0; i < m_nFeatures; ++i)
271 standardizedFeatures.push_back(m_featureInfoList.at(i).StandardizeParameter(features.at(i).Get()));
274 double classScore(0.);
277 classScore += supportVectorInfo.m_yAlpha * m_kernelFunction(supportVectorInfo.m_supportVector,
278 (m_standardizeFeatures ? standardizedFeatures : features), m_scaleFactor);
281 return classScore + m_bias;
Header file for the lar support vector machine class.
#define PANDORA_THROW_RESULT_IF(StatusCode1, Operator, Command)
#define PANDORA_RETURN_RESULT_IF_AND_IF(StatusCode1, StatusCode2, Operator, Command)
Header file for the xml helper class.
MvaTypes::MvaFeatureVector MvaFeatureVector
KernelType
KernelType enum.
static double LinearKernel(const LArMvaHelper::MvaFeatureVector &supportVector, const LArMvaHelper::MvaFeatureVector &features, const double scaleFactor=1.)
A linear kernel.
SupportVectorMachine()
Default constructor.
StatusCodeException class.
bool LoadFile(TiXmlEncoding encoding=TIXML_DEFAULT_ENCODING)
TiXmlElement * Element() const
TiXmlHandle FirstChild() const
Return a handle to the first child node.
const std::string & ValueStr() const
const TiXmlElement * NextSiblingElement() const
const TiXmlNode * NextSibling(const std::string &_value) const
STL std::string form.
static StatusCode ReadVectorOfValues(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, std::vector< T > &vector)
Read a vector of values from a (space separated) list in an xml element.
static StatusCode ReadValue(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, T &t)
Read a value from an xml element.
StatusCode
The StatusCode enum.