diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml index 84730998b21f..2d098db40c01 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -13,7 +13,7 @@ asr: fused_batch_size: -1 greedy: use_cuda_graph_decoder: false # Disabled due to issues with decoding - enable_per_stream_biasing: false # Per-stream biasing in decoder + enable_per_stream_biasing: true # Per-stream biasing in decoder max_symbols: 10 # n-gram LM ngram_lm_model: null # The path to built '.nemo' NGPU-LM model diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 1367f0514247..2b4de83dac1c 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -40,6 +40,7 @@ get_confidence_utils, ) from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.utils import logging if TYPE_CHECKING: from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer @@ -306,6 +307,36 @@ def cache_aware_transcribe_step( eos_flags.append(request.is_last) previous_hypotheses = [state.get_previous_hypothesis() for state in states] + + try: + decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer + biasing_enabled = decoding_computer.per_stream_biasing_enabled + except AttributeError: + decoding_computer = None + biasing_enabled = False + + if not biasing_enabled and any(state.has_biasing_request() for state in states): + logging.warning("Biasing request is not empty, but decoder does not support per-stream biasing. Skipping") + # Handle per-stream biasing: add biasing models to multi_model if needed + if biasing_enabled: + for i, (request, state, previous_hyp) in enumerate(zip(requests, states, previous_hypotheses)): + if state.has_biasing_request(): + if state.options.biasing_cfg.multi_model_id is None: + if state.options.biasing_cfg.auto_manage_multi_model: + state.options.biasing_cfg.add_to_multi_model( + tokenizer=self.asr_model.tokenizer, + biasing_multi_model=decoding_computer.biasing_multi_model, + ) + else: + logging.warning( + "Biasing request is not empty, not auto managed and not compiled. Skipping" + ) + if previous_hyp is None: + previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(state.options.biasing_cfg) + else: + previous_hyp.biasing_cfg = state.options.biasing_cfg + context, mapping = self.context_manager.get_context(stream_ids) prompt_vectors = None @@ -344,6 +375,16 @@ def cache_aware_transcribe_step( state.cleanup_after_eou() ready_state_ids.add(request.stream_id) + # Cleanup per-stream biasing models when stream ends + if biasing_enabled: + for request, state in zip(requests, states): + # only the first request contains biasing options; biasing options for the stream are stored in state + if request.is_last and state.has_biasing_request(): + if state.options.biasing_cfg.auto_manage_multi_model: + state.options.biasing_cfg.remove_from_multi_model( + biasing_multi_model=decoding_computer.biasing_multi_model + ) + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ Transcribes the feature buffers in a streaming manner. 
diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index d3efd80e4396..c9374c37ba26 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -51,7 +51,7 @@ def set_previous_hypothesis(self, previous_hypothesis: Hypothesis) -> None: """ self.previous_hypothesis = previous_hypothesis - def get_previous_hypothesis(self) -> Hypothesis: + def get_previous_hypothesis(self) -> Hypothesis | None: """ Get the previous hypothesis Returns: diff --git a/tests/functional_tests/L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT.sh b/tests/functional_tests/L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT.sh index b2c73ed8a34d..77a9e39c673f 100644 --- a/tests/functional_tests/L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT.sh +++ b/tests/functional_tests/L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT.sh @@ -42,3 +42,18 @@ coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo \ enable_itn=False \ enable_nmt=False \ asr_output_granularity=segment + +# Cache-Aware RNN-T model +coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo \ + examples/asr/asr_streaming_inference/asr_streaming_infer.py \ + --config-path="../conf/asr_streaming_inference/" \ + --config-name=cache_aware_rnnt.yaml \ + audio_file="/home/TestData/asr/canary/dev-other-wav-10-boost-gt.json" \ + output_filename="/tmp/stt_inference_boost_gt_res_ca_rnnt.json" \ + asr.model_name="nvidia/nemotron-speech-streaming-en-0.6b" \ + streaming.batch_size=5 \ + lang=en \ + enable_pnc=False \ + enable_itn=False \ + enable_nmt=False \ + asr_output_granularity=segment