This reverts commit 92d88da.
```cpp
if (not fwd and threadIdx.x == 0 and t == seq_lens[seq]) {
  prev_states[final_states[seq]] = 0.0;
  for (unsigned fs = final_state_offsets[seq]; fs < final_state_offsets[seq+1]; fs++) {
```
Does `fs < final_state_offsets[seq+1]` work for the last sequence? E.g., are we adding `num_final_states` to the end of the offsets, or are we expecting the user to do it?
Yes, so far I'm expecting the user to do it. I'm not sure it's the best way, though. An alternative would be to provide a `num_end_states` array, from which we compute the offsets, like it's done for the edges.

However, we will have to extend the logic for the edge-offset computation at some point anyway if we want to use a single shared automaton for all sequences.

My idea would be to have a `[B,2]` array of start and end points to read out from the edge tensor. The user wouldn't know about it and would just have the usual `fbw_loss` plus an additional `fbw_loss_shared_automaton` or something. I just think it would be the easiest way to reuse the code that @curufinwe wrote for the "shared automaton" case.
I would try to avoid relying on the user doing the right thing. Adding `len(prev_states)` to `final_state_offsets` if its length is equal to `num_seqs` seems like a simple fix to me.
```cpp
unsigned* d_start_states = Ndarray_DEV_DATA_uint32(start_states);
unsigned* d_end_states = Ndarray_DEV_DATA_uint32(end_states);
unsigned* d_end_state_offsets = Ndarray_DEV_DATA_uint32(end_state_offsets);
unsigned* d_seq_lens = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(seq_lens));
```
Why can we remove the `reinterpret_cast` above but not here?
I'm not sure, but I think we should be able to handle it the way the others are handled.
```diff
-    state_offsets.data().get(), edge_offsets.data().get(), d_seq_lens,
-    d_from, d_to, d_weights, d_emission_idxs, d_start_states, d_end_states,
+    state_offsets.data().get(), d_edge_offsets.data().get(), d_seq_lens,
+    d_from, d_to, d_weights, d_emission_idxs, d_start_states, d_end_states, d_end_state_offsets,
```
Since the new parameter does not work for the v1 `baum_welch` implementation, should we add a check somewhere that `debug_options.explicit_merge` is not set when we have multiple end states?
Yes that's a good idea. Alternatively I could adjust the v2 accordingly.
Co-authored-by: michelwi <michelwi@users.noreply.github.com>
This is one step towards computing the full "denominator".
It's a draft so far because I'd appreciate some input on the design: currently I extended the FBW2 Op by two parameters, where we recover the "usual" behavior if we pass an empty tensor for the extra parameter `end_state_offsets`. Alternatives would be:

- a separate `fbw2_cuda` function: this creates a lot of code redundancy in my opinion
- overloading `fbw2_cuda` and trying to call the more general function from the more specific one
- `std::optional`, but I'm far from an expert on that

The same questions can be asked for the python entry point, e.g. if we would like to have a separate `multi_end_fbw2` function for the python interface.