11#include <torch/script.h>
12#include <torch/torch.h>
37 m_useTrainingMode(false),
38 m_trainingOutputFile(
"")
62 const int SHOWER{1}, TRACK{2};
70 const HitType view{pCaloHitList->front()->GetHitType()};
73 return STATUS_CODE_NOT_ALLOWED;
78 trainingOutputFileName +=
"_CaloHitListU.csv";
80 trainingOutputFileName +=
"_CaloHitListV.csv";
82 trainingOutputFileName +=
"_CaloHitListW.csv";
94 for (
const CaloHit *pCaloHit : *pCaloHitList)
97 float inputEnergy{0.f};
103 if (targetMCParticleToHitsMap.find(pMCParticle) == targetMCParticleToHitsMap.end())
107 inputEnergy = pCaloHit->GetInputEnergy();
108 if (inputEnergy < 0.f)
112 if (pdg == 11 || pdg == 22)
122 featureVector.push_back(
static_cast<double>(pCaloHit->GetPositionVector().GetX()));
123 featureVector.push_back(
static_cast<double>(pCaloHit->GetPositionVector().GetZ()));
124 featureVector.push_back(
static_cast<double>(tag));
125 featureVector.push_back(
static_cast<double>(inputEnergy));
128 featureVector.push_back(
static_cast<double>(featureVector.size() / 4));
129 std::rotate(featureVector.rbegin(), featureVector.rbegin() + 1, featureVector.rend());
134 return STATUS_CODE_SUCCESS;
141 const float eps{1.1920929e-7};
153 const HitType view{pCaloHitList->front()->GetHitType()};
156 return STATUS_CODE_NOT_ALLOWED;
165 this->
GetHitRegion(*pCaloHitList, xMin, xMax, zMin, zMax);
166 const float xRange = (xMax + eps) - (xMin - eps);
167 int nTilesX =
static_cast<int>(std::ceil(xRange /
m_tileSize));
171 const int nTiles = sparseMap.size();
179 for (
int i = 0; i < nTiles; ++i)
181 for (
const CaloHit *pCaloHit : *pCaloHitList)
183 const float x(pCaloHit->GetPositionVector().GetX());
184 const float z(pCaloHit->GetPositionVector().GetZ());
186 const int tileX =
static_cast<int>(std::floor((x - xMin) /
m_tileSize));
187 const int tileZ =
static_cast<int>(std::floor((z - zMin) /
m_tileSize));
188 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
192 const float localX = std::fmod(x - xMin,
m_tileSize);
193 const float localZ = std::fmod(z - zMin,
m_tileSize);
197 weights[pixelZ][pixelX] += pCaloHit->GetInputEnergy();
202 float chargeMin{std::numeric_limits<float>::max()}, chargeMax{-std::numeric_limits<float>::max()};
207 if (weights[r][c] > chargeMax)
208 chargeMax = weights[r][c];
209 if (weights[r][c] < chargeMin)
210 chargeMin = weights[r][c];
213 float chargeRange{chargeMax - chargeMin};
214 if (chargeRange <= 0.f)
221 auto accessor = input.accessor<float, 4>();
222 for (
const CaloHit *pCaloHit : *pCaloHitList)
224 const float x(pCaloHit->GetPositionVector().GetX());
225 const float z(pCaloHit->GetPositionVector().GetZ());
227 const int tileX =
static_cast<int>(std::floor((x - xMin) /
m_tileSize));
228 const int tileZ =
static_cast<int>(std::floor((z - zMin) /
m_tileSize));
229 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
233 const float localX = std::fmod(x - xMin,
m_tileSize);
234 const float localZ = std::fmod(z - zMin,
m_tileSize);
238 accessor[0][0][pixelZ][pixelX] = (weights[pixelZ][pixelX] - chargeMin) / chargeRange;
239 caloHitToPixelMap.insert(std::make_pair(pCaloHit, std::make_tuple(tileZ, tileX, pixelZ, pixelX)));
249 inputs.push_back(input);
252 auto outputAccessor = output.accessor<float, 4>();
254 for (
const CaloHit *pCaloHit : *pCaloHitList)
256 auto found{caloHitToPixelMap.find(pCaloHit)};
257 if (found == caloHitToPixelMap.end())
259 auto pixelMap = found->second;
260 const int tileZ(std::get<0>(pixelMap));
261 const int tileX(std::get<1>(pixelMap));
262 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
265 const int pixelZ(std::get<2>(pixelMap));
266 const int pixelX(std::get<3>(pixelMap));
269 float probShower = exp(outputAccessor[0][1][pixelZ][pixelX]);
270 float probTrack = exp(outputAccessor[0][2][pixelZ][pixelX]);
271 float probNull = exp(outputAccessor[0][0][pixelZ][pixelX]);
272 if (probShower > probTrack && probShower > probNull)
273 showerHits.push_back(pCaloHit);
274 else if (probTrack > probShower && probTrack > probNull)
275 trackHits.push_back(pCaloHit);
277 otherHits.push_back(pCaloHit);
278 float recipSum = 1.f / (probShower + probTrack);
280 probShower *= recipSum;
281 probTrack *= recipSum;
284 pLArCaloHit->SetTrackProbability(probTrack);
294 const std::string trackListName(
"TrackHits_" + listName);
295 const std::string showerListName(
"ShowerHits_" + listName);
296 const std::string otherListName(
"OtherHits_" + listName);
308 return STATUS_CODE_SUCCESS;
315 xMin = std::numeric_limits<float>::max();
316 xMax = -std::numeric_limits<float>::max();
317 zMin = std::numeric_limits<float>::max();
318 zMax = -std::numeric_limits<float>::max();
319 for (
const CaloHit *pCaloHit : caloHitList)
321 const float x(pCaloHit->GetPositionVector().GetX());
322 const float z(pCaloHit->GetPositionVector().GetZ());
340 std::map<int, bool> tilePopulationMap;
341 for (
const CaloHit *pCaloHit : caloHitList)
343 const float x(pCaloHit->GetPositionVector().GetX());
344 const float z(pCaloHit->GetPositionVector().GetZ());
346 const int tileX =
static_cast<int>(std::floor((x - xMin) /
m_tileSize));
347 const int tileZ =
static_cast<int>(std::floor((z - zMin) /
m_tileSize));
348 const int tile = tileZ * nTilesX + tileX;
349 tilePopulationMap.insert(std::make_pair(tile,
true));
353 for (
auto element : tilePopulationMap)
357 sparseMap.insert(std::make_pair(element.first, nextTile));
375 bool modelLoaded{
false};
402 std::cout <<
"Error: Inference requested, but no model files were successfully loaded" << std::endl;
403 return STATUS_CODE_INVALID_PARAMETER;
414 std::cout <<
"Error: Invalid image size specification" << std::endl;
415 return STATUS_CODE_INVALID_PARAMETER;
419 return STATUS_CODE_SUCCESS;
Header file for the deep learning track shower id algorithm.
Header file for the lar calo hit class.
Header file for the file helper class.
Header file for the lar monte carlo particle helper class.
Header file for the lar monitoring helper class.
Header file for the pfo helper class.
#define PANDORA_RETURN_RESULT_IF_AND_IF(StatusCode1, StatusCode2, Operator, Command)
#define PANDORA_RETURN_RESULT_IF(StatusCode1, Operator, Command)
static pandora::StatusCode GetCurrentList(const pandora::Algorithm &algorithm, const T *&pT)
Get the current list.
static pandora::StatusCode GetList(const pandora::Algorithm &algorithm, const std::string &listName, const T *&pT)
Get a named list.
void SetShowerProbability(const float probability)
Set the probability that the hit is shower-like.
static std::string FindFileInPath(const std::string &unqualifiedFileName, const std::string &environmentVariable, const std::string &delimiter=":")
Find the fully-qualified file name by searching through a list of delimiter-separated paths held in a named environment variable.
float m_maxPhotonPropagation
the maximum photon propagation length
unsigned int m_minHitsForGoodView
the minimum number of Hits for a good view
static void SelectReconstructableMCParticles(const pandora::MCParticleList *pMCParticleList, const pandora::CaloHitList *pCaloHitList, const PrimaryParameters ¶meters, std::function< bool(const pandora::MCParticle *const)> fCriteria, MCContributionMap &selectedMCParticlesToHitsMap)
Select target, reconstructable mc particles that match given criteria.
std::unordered_map< const pandora::MCParticle *, pandora::CaloHitList > MCContributionMap
static bool IsDescendentOf(const pandora::MCParticle *const pMCParticle, const int pdg, const bool isChargeSensitive=false)
Determine if the MC particle is a descendent of a particle with the given PDG code.
static bool IsBeamNeutrinoFinalState(const pandora::MCParticle *const pMCParticle)
Returns true if passed a primary neutrino final state MCParticle.
MvaTypes::MvaFeatureVector MvaFeatureVector
static pandora::StatusCode ProduceTrainingExample(const std::string &trainingOutputFile, const bool result, TCONTAINER &&featureContainer)
Produce a training example with the given features and result.
std::map< const pandora::CaloHit *, std::tuple< int, int, int, int > > CaloHitToPixelMap
DlHitTrackShowerIdAlgorithm()
Default constructor.
virtual ~DlHitTrackShowerIdAlgorithm()
pandora::StatusCode Run()
Run the algorithm.
std::string m_modelFileNameV
Model file name for V view.
LArDLHelper::TorchModel m_modelU
Model for the U view.
std::string m_modelFileNameW
Model file name for W view.
pandora::StatusCode ReadSettings(const pandora::TiXmlHandle xmlHandle)
Read the algorithm settings.
bool m_useTrainingMode
Whether to run in training mode.
pandora::StatusCode Train()
Produce files that act as inputs to network training.
std::string m_trainingOutputFile
Output file name for training examples.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax)
Identify the XZ range containing the hits for an event.
pandora::StringVector m_caloHitListNames
Names of the input calo hit lists.
LArDLHelper::TorchModel m_modelW
Model for the W view.
void GetSparseTileMap(const pandora::CaloHitList &caloHitList, const float xMin, const float zMin, const int nTilesX, PixelToTileMap &sparseMap)
Populate a map from the index of each populated tile in the full tile grid to its sparse tile index.
bool m_visualize
Whether to visualize the track shower ID scores.
pandora::StatusCode Infer()
Run network inference.
float m_tileSize
Size of tile in cm.
std::string m_modelFileNameU
Model file name for U view.
LArDLHelper::TorchModel m_modelV
Model for the V view.
int m_imageWidth
Width of images in pixels.
std::map< int, int > PixelToTileMap
int m_imageHeight
Height of images in pixels.
torch::jit::script::Module TorchModel
static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)
Create a torch input tensor.
std::vector< torch::jit::IValue > TorchInputVector
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)
Run a deep learning model.
static const MCParticle * GetMainMCParticle(const T *const pT)
Find the mc particle making the largest contribution to a specified calo hit, track or cluster.
int GetParticleId() const
Get the PDG code of the mc particle.
const Pandora & GetPandora() const
Get the associated pandora instance.
StatusCodeException class.
static StatusCode ReadVectorOfValues(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, std::vector< T > &vector)
Read a vector of values from a (space separated) list in an xml element.
static StatusCode ReadValue(const TiXmlHandle &xmlHandle, const std::string &xmlElementName, T &t)
Read a value from an xml element.
HitType
Calorimeter hit type enum.
MANAGED_CONTAINER< const MCParticle * > MCParticleList
MANAGED_CONTAINER< const CaloHit * > CaloHitList
StatusCode
The StatusCode enum.