Skip to content

Commit b72aed1

Browse files
ikantakIsabel Kantak
andauthored
[PWGEM] Add ML-based photon cuts (#14479)
Co-authored-by: Isabel Kantak <[email protected]>
1 parent a064083 commit b72aed1

File tree

10 files changed

+758
-47
lines changed

10 files changed

+758
-47
lines changed

PWGEM/PhotonMeson/Core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ o2physics_add_library(PWGEMPhotonMesonCore
1919
CutsLibrary.cxx
2020
HistogramsLibrary.cxx
2121
EMBitFlags.cxx
22-
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::MLCore O2Physics::PWGEMDileptonCore)
22+
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::MLCore O2Physics::PWGEMDileptonCore KFParticle::KFParticle)
2323

2424
o2physics_target_root_dictionary(PWGEMPhotonMesonCore
2525
HEADERS V0PhotonCut.h
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
// \file EmMlResponse.h
13+
// \brief Class to compute the ML response for EM-analysis selections
14+
// \author Isabel Kantak <[email protected]>, University of Heidelberg
15+
16+
#ifndef PWGEM_PHOTONMESON_CORE_EMMLRESPONSE_H_
17+
#define PWGEM_PHOTONMESON_CORE_EMMLRESPONSE_H_
18+
19+
#include "Tools/ML/MlResponse.h"
20+
21+
namespace o2::analysis
22+
{
23+
24+
template <typename TypeOutputScore = float>
25+
class EmMlResponse : public MlResponse<TypeOutputScore>
26+
{
27+
public:
28+
/// Default constructor
29+
EmMlResponse() = default;
30+
/// Default destructor
31+
virtual ~EmMlResponse() = default;
32+
};
33+
34+
} // namespace o2::analysis
35+
36+
#endif // PWGEM_PHOTONMESON_CORE_EMMLRESPONSE_H_
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file EmMLResponsePCM.h
13+
/// \brief Class to compute the ML response for PCM analysis selections
14+
/// \author Isabel Kantak <[email protected]>, University of Heidelberg
15+
16+
#ifndef PWGEM_PHOTONMESON_CORE_EMMLRESPONSEPCM_H_
17+
#define PWGEM_PHOTONMESON_CORE_EMMLRESPONSEPCM_H_
18+
19+
#include "PWGEM/PhotonMeson/Core/EmMlResponse.h"
20+
21+
#include "Tools/ML/MlResponse.h"
22+
23+
#include <cstdint>
24+
#include <vector>
25+
26+
// Fill the map of available input features
27+
// the key is the feature's name (std::string)
28+
// the value is the corresponding value in EnumInputFeatures
29+
#define FILL_MAP_PCM(FEATURE) \
30+
{ \
31+
#FEATURE, static_cast<uint8_t>(InputFeaturesPCM::FEATURE) \
32+
}
33+
34+
// Check if the index of mCachedIndices (index associated to a FEATURE)
35+
// matches the entry in EnumInputFeatures associated to this FEATURE
36+
// if so, the inputFeatures vector is filled with the FEATURE's value
37+
// by calling the corresponding GETTER from OBJECT
38+
#define CHECK_AND_FILL_VEC_PCM_FULL(OBJECT, FEATURE, GETTER) \
39+
case static_cast<uint8_t>(InputFeaturesPCM::FEATURE): { \
40+
inputFeatures.emplace_back(OBJECT.GETTER()); \
41+
break; \
42+
}
43+
44+
// Specific case of CHECK_AND_FILL_VEC_PCM_FULL(OBJECT, FEATURE, GETTER)
45+
// where OBJECT is named candidate and FEATURE = GETTER
46+
#define CHECK_AND_FILL_VEC_PCM(GETTER) \
47+
case static_cast<uint8_t>(InputFeaturesPCM::GETTER): { \
48+
inputFeatures.emplace_back(candidate.GETTER()); \
49+
break; \
50+
}
51+
52+
namespace o2::analysis
53+
{
54+
55+
enum class InputFeaturesPCM : uint8_t {
56+
v0PhotonCandidatefDCAxyToPV,
57+
v0PhotonCandidatefDCAzToPV,
58+
v0PhotonCandidatefPCA,
59+
v0PhotonCandidatefAlpha,
60+
v0PhotonCandidatefQtArm,
61+
v0PhotonCandidatefChiSquareNDF,
62+
v0PhotonCandidatefCosPA,
63+
posV0LegfTPCNSigmaEl,
64+
posV0LegfTPCNSigmaPi,
65+
negV0LegfTPCNSigmaEl,
66+
negV0LegfTPCNSigmaPi
67+
};
68+
69+
template <typename TypeOutputScore = float>
70+
class EmMlResponsePCM : public EmMlResponse<TypeOutputScore>
71+
{
72+
public:
73+
/// Default constructor
74+
EmMlResponsePCM() = default;
75+
/// Default destructor
76+
virtual ~EmMlResponsePCM() = default;
77+
78+
/// Method to get the input features vector needed for ML inference
79+
/// \param candidate is the V0photon candidate
80+
/// \return inputFeatures vector
81+
template <typename T1, typename T2>
82+
std::vector<float> getInputFeatures(T1 const& candidate, T2 const& posLeg, T2 const& negLeg)
83+
{
84+
std::vector<float> inputFeatures;
85+
86+
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
87+
switch (idx) {
88+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefDCAxyToPV, GetDcaXYToPV);
89+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefDCAzToPV, GetDcaZToPV);
90+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefPCA, GetPCA);
91+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefAlpha, GetAlpha);
92+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefQtArm, GetQt);
93+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefChiSquareNDF, GetChi2NDF);
94+
CHECK_AND_FILL_VEC_PCM_FULL(candidate, v0PhotonCandidatefCosPA, GetCosPA);
95+
CHECK_AND_FILL_VEC_PCM_FULL(posLeg, posV0LegfTPCNSigmaEl, tpcNSigmaEl);
96+
CHECK_AND_FILL_VEC_PCM_FULL(posLeg, posV0LegfTPCNSigmaPi, tpcNSigmaPi);
97+
CHECK_AND_FILL_VEC_PCM_FULL(negLeg, negV0LegfTPCNSigmaEl, tpcNSigmaEl);
98+
CHECK_AND_FILL_VEC_PCM_FULL(negLeg, negV0LegfTPCNSigmaPi, tpcNSigmaPi);
99+
}
100+
}
101+
return inputFeatures;
102+
}
103+
104+
protected:
105+
/// Method to fill the map of available input features
106+
void setAvailableInputFeatures()
107+
{
108+
MlResponse<TypeOutputScore>::mAvailableInputFeatures = {
109+
FILL_MAP_PCM(v0PhotonCandidatefDCAxyToPV),
110+
FILL_MAP_PCM(v0PhotonCandidatefDCAzToPV),
111+
FILL_MAP_PCM(v0PhotonCandidatefPCA),
112+
FILL_MAP_PCM(v0PhotonCandidatefAlpha),
113+
FILL_MAP_PCM(v0PhotonCandidatefQtArm),
114+
FILL_MAP_PCM(v0PhotonCandidatefChiSquareNDF),
115+
FILL_MAP_PCM(v0PhotonCandidatefCosPA),
116+
FILL_MAP_PCM(posV0LegfTPCNSigmaEl),
117+
FILL_MAP_PCM(posV0LegfTPCNSigmaPi),
118+
FILL_MAP_PCM(negV0LegfTPCNSigmaEl),
119+
FILL_MAP_PCM(negV0LegfTPCNSigmaPi)};
120+
}
121+
};
122+
123+
} // namespace o2::analysis
124+
125+
#undef FILL_MAP_PCM
126+
#undef CHECK_AND_FILL_VEC_PCM_FULL
127+
#undef CHECK_AND_FILL_VEC_PCM
128+
129+
#endif // PWGEM_PHOTONMESON_CORE_EMMLRESPONSEPCM_H_

PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include <Framework/AnalysisDataModel.h>
4646
#include <Framework/AnalysisHelpers.h>
4747
#include <Framework/AnalysisTask.h>
48+
#include <Framework/Array2D.h>
4849
#include <Framework/Configurable.h>
4950
#include <Framework/HistogramRegistry.h>
5051
#include <Framework/HistogramSpec.h>
@@ -140,6 +141,19 @@ struct Pi0EtaToGammaGamma {
140141
o2::framework::Configurable<float> cfg_max_TPCNsigmaEl{"cfg_max_TPCNsigmaEl", +3.0, "max. TPC n sigma for electron"};
141142
o2::framework::Configurable<bool> cfg_disable_itsonly_track{"cfg_disable_itsonly_track", false, "flag to disable ITSonly tracks"};
142143
o2::framework::Configurable<bool> cfg_disable_tpconly_track{"cfg_disable_tpconly_track", false, "flag to disable TPConly tracks"};
144+
145+
o2::framework::Configurable<bool> cfg_apply_ml_cuts{"cfg_apply_ml", false, "flag to apply ML cut"};
146+
o2::framework::Configurable<bool> cfg_use_2d_binning{"cfg_use_2d_binning", true, "flag to use 2D binning (pT, cent)"};
147+
o2::framework::Configurable<bool> cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"};
148+
o2::framework::Configurable<int> cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"};
149+
o2::framework::Configurable<int> cfg_nclasses_ml{"cfg_nclasses_ml", static_cast<int>(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"};
150+
o2::framework::Configurable<std::vector<int>> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"};
151+
o2::framework::Configurable<std::vector<std::string>> cfg_input_feature_names{"cfg_input_feature_names", std::vector<std::string>{"feature1", "feature2"}, "input feature names for ML models"};
152+
o2::framework::Configurable<std::vector<std::string>> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector<std::string>{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"};
153+
o2::framework::Configurable<std::vector<std::string>> cfg_onnx_file_names{"cfg_onnx_file_names", std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}, "ONNX file names for ML models"};
154+
o2::framework::Configurable<std::vector<double>> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector<double>{o2::analysis::em_cuts_ml::vecBinsPt}, "pT bins for ML"};
155+
o2::framework::Configurable<std::vector<double>> cfg_bins_cent_ml{"cfg_bins_cent_ml", std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}, "centrality bins for ML"};
156+
o2::framework::Configurable<o2::framework::LabeledArray<double>> cfg_cuts_pcm_ml{"cfg_cuts_pcm_ml", {o2::analysis::em_cuts_ml::Cuts[0], o2::analysis::em_cuts_ml::NBinsPt, o2::analysis::em_cuts_ml::NCutScores, o2::analysis::em_cuts_ml::labelsPt, o2::analysis::em_cuts_ml::labelsCutScore}, "ML selections per pT bin"};
143157
} pcmcuts;
144158

145159
DalitzEECut fDileptonCut;
@@ -482,6 +496,25 @@ struct Pi0EtaToGammaGamma {
482496
fV0PhotonCut.SetRequireITSTPC(pcmcuts.cfg_require_v0_with_itstpc);
483497
fV0PhotonCut.SetRequireITSonly(pcmcuts.cfg_require_v0_with_itsonly);
484498
fV0PhotonCut.SetRequireTPConly(pcmcuts.cfg_require_v0_with_tpconly);
499+
500+
// for ML
501+
fV0PhotonCut.SetApplyMlCuts(pcmcuts.cfg_apply_ml_cuts);
502+
fV0PhotonCut.SetUse2DBinning(pcmcuts.cfg_use_2d_binning);
503+
fV0PhotonCut.SetLoadMlModelsFromCCDB(pcmcuts.cfg_load_ml_models_from_ccdb);
504+
fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml);
505+
fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb);
506+
fV0PhotonCut.SetCcdbUrl(ccdburl);
507+
fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml);
508+
fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb);
509+
fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names);
510+
fV0PhotonCut.SetBinsPtMl(pcmcuts.cfg_bins_pt_ml);
511+
fV0PhotonCut.SetBinsCentMl(pcmcuts.cfg_bins_cent_ml);
512+
fV0PhotonCut.SetCutsPCMMl(pcmcuts.cfg_cuts_pcm_ml);
513+
fV0PhotonCut.SetNamesInputFeatures(pcmcuts.cfg_input_feature_names);
514+
515+
if (pcmcuts.cfg_apply_ml_cuts) {
516+
fV0PhotonCut.initV0MlModels(ccdbApi);
517+
}
485518
}
486519

487520
void DefineDileptonCut()
@@ -662,6 +695,8 @@ struct Pi0EtaToGammaGamma {
662695
{
663696
for (const auto& collision : collisions) {
664697
initCCDB(collision);
698+
fV0PhotonCut.SetCentrality(collision.centFT0M());
699+
fV0PhotonCut.SetD_Bz(d_bz);
665700
int ndiphoton = 0;
666701
if ((pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPHOSPHOS || pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPCMPHOS) && !collision.alias_bit(triggerAliases::kTVXinPHOS)) {
667702
continue;

PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <Framework/AnalysisDataModel.h>
4343
#include <Framework/AnalysisHelpers.h>
4444
#include <Framework/AnalysisTask.h>
45+
#include <Framework/Array2D.h>
4546
#include <Framework/Configurable.h>
4647
#include <Framework/HistogramRegistry.h>
4748
#include <Framework/HistogramSpec.h>
@@ -129,6 +130,19 @@ struct Pi0EtaToGammaGammaMC {
129130
o2::framework::Configurable<float> cfg_max_TPCNsigmaEl{"cfg_max_TPCNsigmaEl", +3.0, "max. TPC n sigma for electron"};
130131
o2::framework::Configurable<bool> cfg_disable_itsonly_track{"cfg_disable_itsonly_track", false, "flag to disable ITSonly tracks"};
131132
o2::framework::Configurable<bool> cfg_disable_tpconly_track{"cfg_disable_tpconly_track", false, "flag to disable TPConly tracks"};
133+
134+
o2::framework::Configurable<bool> cfg_apply_ml_cuts{"cfg_apply_ml", false, "flag to apply ML cut"};
135+
o2::framework::Configurable<bool> cfg_use_2d_binning{"cfg_use_2d_binning", true, "flag to use 2D binning (pT, cent)"};
136+
o2::framework::Configurable<bool> cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"};
137+
o2::framework::Configurable<int> cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"};
138+
o2::framework::Configurable<int> cfg_nclasses_ml{"cfg_nclasses_ml", static_cast<int>(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"};
139+
o2::framework::Configurable<std::vector<int>> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"};
140+
o2::framework::Configurable<std::vector<std::string>> cfg_input_feature_names{"cfg_input_feature_names", std::vector<std::string>{"feature1", "feature2"}, "input feature names for ML models"};
141+
o2::framework::Configurable<std::vector<std::string>> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector<std::string>{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"};
142+
o2::framework::Configurable<std::vector<std::string>> cfg_onnx_file_names{"cfg_onnx_file_names", std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}, "ONNX file names for ML models"};
143+
o2::framework::Configurable<std::vector<double>> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector<double>{o2::analysis::em_cuts_ml::vecBinsPt}, "pT bins for ML"};
144+
o2::framework::Configurable<std::vector<double>> cfg_bins_cent_ml{"cfg_bins_cent_ml", std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}, "centrality bins for ML"};
145+
o2::framework::Configurable<o2::framework::LabeledArray<double>> cfg_cuts_pcm_ml{"cfg_cuts_pcm_ml", {o2::analysis::em_cuts_ml::Cuts[0], o2::analysis::em_cuts_ml::NBinsPt, o2::analysis::em_cuts_ml::NCutScores, o2::analysis::em_cuts_ml::labelsPt, o2::analysis::em_cuts_ml::labelsCutScore}, "ML selections per pT bin"};
132146
} pcmcuts;
133147

134148
DalitzEECut fDileptonCut;
@@ -322,6 +336,25 @@ struct Pi0EtaToGammaGammaMC {
322336
fV0PhotonCut.SetRequireITSTPC(pcmcuts.cfg_require_v0_with_itstpc);
323337
fV0PhotonCut.SetRequireITSonly(pcmcuts.cfg_require_v0_with_itsonly);
324338
fV0PhotonCut.SetRequireTPConly(pcmcuts.cfg_require_v0_with_tpconly);
339+
340+
// for ML
341+
fV0PhotonCut.SetApplyMlCuts(pcmcuts.cfg_apply_ml_cuts);
342+
fV0PhotonCut.SetUse2DBinning(pcmcuts.cfg_use_2d_binning);
343+
fV0PhotonCut.SetLoadMlModelsFromCCDB(pcmcuts.cfg_load_ml_models_from_ccdb);
344+
fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml);
345+
fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb);
346+
fV0PhotonCut.SetCcdbUrl(ccdburl);
347+
fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml);
348+
fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb);
349+
fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names);
350+
fV0PhotonCut.SetBinsPtMl(pcmcuts.cfg_bins_pt_ml);
351+
fV0PhotonCut.SetBinsCentMl(pcmcuts.cfg_bins_cent_ml);
352+
fV0PhotonCut.SetCutsPCMMl(pcmcuts.cfg_cuts_pcm_ml);
353+
fV0PhotonCut.SetNamesInputFeatures(pcmcuts.cfg_input_feature_names);
354+
355+
if (pcmcuts.cfg_apply_ml_cuts) {
356+
fV0PhotonCut.initV0MlModels(ccdbApi);
357+
}
325358
}
326359

327360
void DefineDileptonCut()
@@ -520,7 +553,8 @@ struct Pi0EtaToGammaGammaMC {
520553
{
521554
for (auto& collision : collisions) {
522555
initCCDB(collision);
523-
556+
fV0PhotonCut.SetCentrality(collision.centFT0M());
557+
fV0PhotonCut.SetD_Bz(d_bz);
524558
if ((pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPHOSPHOS || pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPCMPHOS) && !collision.alias_bit(triggerAliases::kTVXinPHOS)) {
525559
continue;
526560
}

0 commit comments

Comments
 (0)