
Allow Multiple End States in FBW2#13

Draft
DanEnergetics wants to merge 17 commits into main from fbw2_multi_end

Conversation

DanEnergetics (Contributor) commented Jun 17, 2025

This is one step towards computing the full "denominator".
It's a draft so far because I'd appreciate some input on the design: currently I extended the FBW2 Op by two parameters, where we recover the "usual" behavior if we pass an empty tensor for the extra parameter end_state_offsets.
Alternatives would be:

  • copying the fbw2_cuda function: this creates a lot of code redundancy in my opinion
  • overloading fbw2_cuda and trying to call the more general function from the more specific one
  • instead of using empty tensors we could use something like std::optional, but I'm far from an expert on that

The same questions can be asked for the Python entry point, e.g. whether we would like to have a separate multi_end_fbw2 function for the Python interface.
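As a rough sketch of the second alternative (overloading and delegating), the single-end-state overload could synthesize trivial offsets and forward to the generalized version, so the core logic exists only once. None of these names come from the actual fbw2_cuda signature; they are stand-ins to illustrate the delegation pattern only:

```cpp
#include <vector>

// General version (hypothetical): the end states of sequence `seq` live in
// end_states[end_state_offsets[seq] .. end_state_offsets[seq+1]).
unsigned num_end_states(const std::vector<unsigned>& end_state_offsets,
                        unsigned seq) {
    return end_state_offsets[seq + 1] - end_state_offsets[seq];
}

// Specific overload: exactly one end state per sequence. It builds the
// trivial offsets [0, 1, ..., B] and delegates to the general version,
// recovering the "usual" behavior without duplicating the core code.
unsigned num_end_states(unsigned num_seqs, unsigned seq) {
    std::vector<unsigned> offsets(num_seqs + 1);
    for (unsigned i = 0; i <= num_seqs; ++i) offsets[i] = i;
    return num_end_states(offsets, seq);
}
```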


if (not fwd and threadIdx.x == 0 and t == seq_lens[seq]) {
  prev_states[final_states[seq]] = 0.0;
  for (unsigned fs = final_state_offsets[seq]; fs < final_state_offsets[seq+1]; fs++) {


does fs < final_state_offsets[seq+1] work for the last sequence? e.g. are we adding num_final_states to the end of the offsets or are we expecting the user to do it?

DanEnergetics (author) replied Jun 20, 2025


Yes, so far I'm expecting the user to do it. I'm not sure if that's the best way, though. An alternative would be to provide a num_end_states array, from which we compute the offsets like it's done for the edges.
However, we will have to extend the logic for the edge offset computation at some point if we want to use a single shared automaton for all sequences.
My idea would be to have a [B,2]-array of start- and end-points to read out from the edge tensor. The user wouldn't know of it, though, and would just use the usual fbw_loss plus an additional fbw_loss_shared_automaton or something. I just think it would be the easiest way to reuse the code that @curufinwe wrote for the "shared automaton" case.
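The num_end_states alternative mentioned above would amount to an exclusive prefix sum, the same pattern used for the edge offsets. A minimal host-side sketch (the helper name is hypothetical, not part of the actual code):

```cpp
#include <numeric>
#include <vector>

// Hypothetical helper: turn per-sequence end-state counts into offsets,
// analogous to how the edge offsets are computed. The trailing entry makes
// offsets[seq + 1] valid for the last sequence as well.
std::vector<unsigned> offsets_from_counts(const std::vector<unsigned>& counts) {
    std::vector<unsigned> offsets(counts.size() + 1, 0);
    // exclusive prefix sum: offsets[i + 1] = counts[0] + ... + counts[i]
    std::partial_sum(counts.begin(), counts.end(), offsets.begin() + 1);
    return offsets;
}
```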


I would try to avoid relying on the user doing the right thing. And appending len(prev_states) to final_state_offsets if its length is equal to num_seqs seems like a simple fix to me.
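The suggested fix could look roughly like this; container type and names are placeholders for the actual tensor handling, and the function returns the completed offsets for illustration:

```cpp
#include <vector>

// Hedged sketch of the fix suggested above: if the caller passed only
// num_seqs offsets (one per sequence), append the total number of final
// states so that final_state_offsets[seq + 1] is also valid for the last
// sequence. If the caller already appended it, nothing changes.
std::vector<unsigned> complete_final_state_offsets(
        std::vector<unsigned> final_state_offsets,
        unsigned num_seqs, unsigned num_final_states) {
    if (final_state_offsets.size() == num_seqs)
        final_state_offsets.push_back(num_final_states);
    return final_state_offsets;
}
```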

unsigned* d_start_states = Ndarray_DEV_DATA_uint32(start_states);
unsigned* d_end_states = Ndarray_DEV_DATA_uint32(end_states);
unsigned* d_end_state_offsets = Ndarray_DEV_DATA_uint32(end_state_offsets);
unsigned* d_seq_lens = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(seq_lens));


why can we remove the reinterpret_cast above but not here?

DanEnergetics (author) replied:


I'm not sure but I think we should be able to handle it the way the others are handled.
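For context on the cast question, a reduced illustration of what the reinterpret_cast in the snippet does: seq_lens is stored as int32 but the kernel wants unsigned, and reading int32 storage through an unsigned (uint32) view is well-defined since signed and unsigned variants of the same integer type may alias, with identical bit patterns for non-negative lengths. The helper below is hypothetical and assumes 32-bit unsigned:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reduced version of the cast in the snippet above: view an
// int32 buffer as uint32 and read an element. Assumes sizeof(unsigned) == 4,
// which holds on mainstream platforms, and non-negative values.
unsigned read_as_unsigned(std::vector<std::int32_t> data, std::size_t i) {
    return reinterpret_cast<unsigned*>(data.data())[i];
}
```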

-    state_offsets.data().get(), edge_offsets.data().get(), d_seq_lens,
-    d_from, d_to, d_weights, d_emission_idxs, d_start_states, d_end_states,
+    state_offsets.data().get(), d_edge_offsets.data().get(), d_seq_lens,
+    d_from, d_to, d_weights, d_emission_idxs, d_start_states, d_end_states, d_end_state_offsets,


since the new parameter does not work for the v1 baum_welch implementation, should we add some check somewhere that debug_options.explicit_merge is not set when we have multiple ends?

DanEnergetics (author) replied:


Yes, that's a good idea. Alternatively, I could adjust the v2 accordingly.
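A sketch of such a check: the option name debug_options.explicit_merge is taken from the comment above, everything else is a placeholder. It returns a flag here for testability, where the real code would raise an error; a non-empty end_state_offsets signals the multi-end case, matching the empty-tensor convention from the PR description:

```cpp
#include <vector>

// Hedged sketch: the multi-end-state path is only available in the v2
// implementation, so the combination of explicit_merge (v1) with multiple
// end states should be rejected. Returns false when the caller should
// raise an error instead of proceeding.
bool options_compatible(bool explicit_merge,
                        const std::vector<unsigned>& end_state_offsets) {
    const bool multiple_ends = !end_state_offsets.empty();
    return !(explicit_merge && multiple_ends);
}
```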

Co-authored-by: michelwi <michelwi@users.noreply.github.com>


2 participants