Skip to content

Commit 01a6189

Browse files
committed
Remove different smoother options for simplicity
1 parent e643548 commit 01a6189

9 files changed

Lines changed: 101 additions & 339 deletions

File tree

core/include/traccc/finding/actors/measurement_kalman_updater.hpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,24 @@ struct measurement_updater : detray::base_actor {
4848
/// Contains the current track states and some statistics
4949
struct state {
5050
using scalar_t = detray::dscalar<algebra_t>;
51+
using track_state_cand_t = filtered_track_state_candidate<algebra_t>;
5152

5253
constexpr state() = default;
5354

5455
TRACCC_HOST_DEVICE
5556
explicit state(
5657
typename edm::measurement_collection::const_device measurements,
5758
vecmem::data::vector_view<unsigned int> meas_ranges_view,
58-
void* track_state_candidates, const smoother_type smoother)
59-
: m_cand_ptr{track_state_candidates},
59+
vecmem::device_vector<track_state_cand_t>::iterator track_state_itr)
60+
: m_track_state_itr{track_state_itr},
6061
m_measurements{measurements},
61-
m_measurement_ranges{meas_ranges_view},
62-
m_run_smoother{smoother} {}
62+
m_measurement_ranges{meas_ranges_view} {}
6363

6464
/// Calibration configuration
6565
traccc::measurement_selector::config m_calib_cfg{};
6666

6767
/// Track states for the current track
68-
void* m_cand_ptr{nullptr};
68+
vecmem::device_vector<track_state_cand_t>::iterator m_track_state_itr{};
6969

7070
/// Statistics for the current track
7171
track_stats<scalar_t> m_stats{};
@@ -85,9 +85,6 @@ struct measurement_updater : detray::base_actor {
8585
/// Per surface measurement index ranges into measurement cont.
8686
vecmem::device_vector<unsigned int> m_measurement_ranges{
8787
vecmem::data::vector_view<unsigned int>{}};
88-
89-
/// The track candidate collection type pointed to
90-
smoother_type m_run_smoother{smoother_type::e_mbf};
9188
};
9289

9390
/// Select the optimal next measurement and run the KF update
@@ -98,9 +95,6 @@ struct measurement_updater : detray::base_actor {
9895
transporter_result) const {
9996
using scalar_t = detray::dscalar<algebra_t>;
10097

101-
// Check for a valid measurement updater state
102-
assert(updater_state.m_cand_ptr != nullptr);
103-
10498
auto& navigation = propagation.navigation();
10599
auto& stepping = propagation.stepping();
106100

@@ -211,12 +205,9 @@ struct measurement_updater : detray::base_actor {
211205
TRACCC_VERBOSE_HOST_DEVICE("Assigned measurement: %d",
212206
cand.meas_idx);
213207

214-
// TODO: Get host-device compatible visitor implementation
215-
const auto i{
216-
static_cast<int>(updater_state.m_stats.n_track_states)};
217-
traccc::make_track_state_candidate(updater_state.m_cand_ptr,
218-
updater_state.m_run_smoother, i,
219-
cand, bound_param);
208+
*updater_state.m_track_state_itr = {cand.meas_idx, cand.chi2,
209+
bound_param};
210+
updater_state.m_track_state_itr++;
220211

221212
// Update statistics
222213
updater_state.m_stats.ndf_sum +=

core/include/traccc/finding/details/progressive_kalman_filter.hpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ progressive_kalman_filter(
5555
const typename edm::measurement_collection::const_view& measurements_view,
5656
const vecmem::data::vector_view<unsigned int> measurement_ranges_view,
5757
const bound_track_parameters<typename detector_t::algebra_type>& seed,
58-
const unsigned int seed_idx, void* track_state_candidate_ptr,
58+
const unsigned int seed_idx,
59+
vecmem::data::vector_view<
60+
filtered_track_state_candidate<typename detector_t::algebra_type>>
61+
track_state_cand_view,
5962
vecmem::data::jagged_vector_view<typename detector_t::surface_type>
6063
surfaces_view,
6164
const finding_config& cfg) {
@@ -70,6 +73,10 @@ progressive_kalman_filter(
7073
typename edm::measurement_collection::const_device measurements{
7174
measurements_view};
7275

76+
// Create the candidate track states container
77+
vecmem::device_vector<filtered_track_state_candidate<algebra_t>>
78+
track_state_cands{track_state_cand_view};
79+
7380
// Create detray propagator
7481
auto prop_cfg{cfg.propagation};
7582
prop_cfg.navigation.estimate_scattering_noise = false;
@@ -109,9 +116,12 @@ progressive_kalman_filter(
109116
typename detray::actor::pointwise_material_interactor<algebra_t>::state
110117
interactor_state;
111118
// Do the measurement selection
119+
const auto cand_offset{
120+
static_cast<int>(seed_idx * cfg.max_track_candidates_per_track)};
121+
auto track_state_itr =
122+
detray::ranges::detail::next(track_state_cands.begin(), cand_offset);
112123
typename traccc::measurement_updater<algebra_t>::state meas_updater_state{
113-
measurements, measurement_ranges_view, track_state_candidate_ptr,
114-
cfg.run_smoother};
124+
measurements, measurement_ranges_view, track_state_itr};
115125
// Collect the surface geometry identifiers for the Kalman smoother
116126
typename detray::actor::surface_sequencer<
117127
typename detector_t::surface_type>::state sequencer_state{

core/include/traccc/finding/details/run_progressive_kalman_filter.hpp

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -125,52 +125,29 @@ run_progressive_kalman_filter(
125125
// is resizable)
126126
vecmem::copy copy{};
127127
typename edm::track_state_collection<algebra_t>::buffer track_states_buffer{
128-
cfg.run_smoother == smoother_type::e_none ? 0u : n_max_states, mr,
129-
vecmem::data::buffer_type::resizable};
128+
n_max_states, mr, vecmem::data::buffer_type::resizable};
130129
copy.setup(track_states_buffer)->ignore();
131130

132131
// Track data collected by the measurement updater during pattern recog.
133-
vecmem::vector<track_state_candidate> track_cands{};
134132
vecmem::vector<filtered_track_state_candidate<algebra_t>>
135-
filtered_track_cands{};
136-
vecmem::vector<full_track_state_candidate<algebra_t>> full_track_cands{};
137-
138-
switch (cfg.run_smoother) {
139-
case smoother_type::e_none: {
140-
track_cands.resize(n_max_states);
141-
break;
142-
}
143-
case smoother_type::e_kalman: {
144-
filtered_track_cands.resize(n_max_states);
145-
break;
146-
}
147-
case smoother_type::e_mbf: {
148-
full_track_cands.resize(n_max_states);
149-
break;
150-
}
151-
default: {
152-
TRACCC_FATAL_HOST("Unknown smoother option");
153-
return track_container;
154-
}
155-
}
133+
track_state_cands{};
134+
track_state_cands.resize(n_max_states);
156135

157136
// Setup the surface sequence buffer
158137
vecmem::data::jagged_vector_buffer<typename detector_t::surface_type>
159138
sf_sequences_buffer{std::vector<unsigned int>{0u}, mr, &mr,
160139
vecmem::data::buffer_type::resizable};
161140

162-
if (cfg.run_smoother == smoother_type::e_kalman) {
163-
const unsigned int n_surfaces_per_track{
164-
std::max(cfg.max_track_candidates_per_track *
165-
cfg.kalman_smoother.surface_sequence_size_factor,
166-
cfg.kalman_smoother.min_surface_sequence_capacity)};
167-
std::vector<unsigned int> seqs_sizes(n_seeds, n_surfaces_per_track);
168-
169-
sf_sequences_buffer = vecmem::data::jagged_vector_buffer<
170-
typename detector_t::surface_type>{
141+
const unsigned int n_surfaces_per_track{
142+
std::max(cfg.max_track_candidates_per_track *
143+
cfg.kalman_smoother.surface_sequence_size_factor,
144+
cfg.kalman_smoother.min_surface_sequence_capacity)};
145+
std::vector<unsigned int> seqs_sizes(n_seeds, n_surfaces_per_track);
146+
147+
sf_sequences_buffer =
148+
vecmem::data::jagged_vector_buffer<typename detector_t::surface_type>{
171149
seqs_sizes, mr, &mr, vecmem::data::buffer_type::resizable};
172-
copy.setup(sf_sequences_buffer)->ignore();
173-
}
150+
copy.setup(sf_sequences_buffer)->ignore();
174151

175152
for (unsigned int seed_idx = 0u; seed_idx < seeds.size(); ++seed_idx) {
176153
const auto& seed = seeds[seed_idx];
@@ -179,20 +156,12 @@ run_progressive_kalman_filter(
179156
// Add the information also to the clog stream
180157
TRACCC_VERBOSE_HOST("Seed: " << seed_idx);
181158

182-
// Set the data pointer to the beginning of the range of the track
183-
const auto cand_offset{static_cast<unsigned int>(
184-
seed_idx * cfg.max_track_candidates_per_track)};
185-
186-
track_state_candidate_data<algebra_t> candidate_data(
187-
cfg.run_smoother, cand_offset, vecmem::get_data(track_cands),
188-
vecmem::get_data(filtered_track_cands),
189-
vecmem::get_data(full_track_cands));
190-
191159
// Run the progressive filter for this seed
192160
const track_stats<scalar_t> trk_stats =
193161
traccc::details::progressive_kalman_filter(
194162
det, field, measurements_view, vecmem::get_data(meas_ranges),
195-
seed, seed_idx, candidate_data.ptr(), sf_sequences_buffer, cfg);
163+
seed, seed_idx, vecmem::get_data(track_state_cands),
164+
sf_sequences_buffer, cfg);
196165

197166
// Check track stats and build the new track object
198167
const unsigned int n_track_states{trk_stats.n_track_states};
@@ -208,34 +177,30 @@ run_progressive_kalman_filter(
208177
edm::track track =
209178
track_container.tracks.at(track_container.tracks.size() - 1u);
210179

211-
track.fit_outcome() = cfg.run_smoother == smoother_type::e_none
212-
? track_fit_outcome::UNKNOWN
213-
: track_fit_outcome::SUCCESS;
214-
180+
track.fit_outcome() = track_fit_outcome::SUCCESS;
215181
track.params() = seed;
216182
track.ndf() = static_cast<scalar_t>(ndf_sum);
217183
track.chi2() = trk_stats.chi2_sum;
218184
track.pval() = prob(trk_stats.chi2_sum, static_cast<scalar_t>(ndf_sum));
219185
track.nholes() = static_cast<unsigned int>(trk_stats.n_holes);
220186
track.constituent_links().resize(n_track_states);
221187

222-
const auto track_state_offset{
223-
static_cast<unsigned int>(track_container.states.size())};
188+
const auto cand_offset{seed_idx * cfg.max_track_candidates_per_track};
224189

225190
// Generate the track states for this track
226-
for (unsigned int state_idx = track_state_offset;
227-
state_idx < track_state_offset + n_track_states; state_idx++) {
228-
229-
const unsigned int link_idx{state_idx - track_state_offset};
191+
for (unsigned int link_idx = 0u; link_idx < n_track_states;
192+
link_idx++) {
230193

231194
TRACCC_DEBUG_HOST("Adding track state (local idx "
232-
<< link_idx << ", global idx " << state_idx
195+
<< link_idx << ", global idx "
196+
<< link_idx + static_cast<unsigned int>(
197+
track_container.states.size())
233198
<< ")");
234199

235200
// Intermediate type required to build a view
236201
traccc::track_state_from_candidate<algebra_t>(
237-
candidate_data.ptr(), cfg.run_smoother, link_idx, measurements,
238-
track, track_states_buffer);
202+
track_state_cands.at(cand_offset + link_idx), link_idx,
203+
measurements, track, track_states_buffer);
239204
}
240205

241206
TRACCC_DEBUG_HOST("Added track " << track_container.tracks.size() - 1

0 commit comments

Comments
 (0)