Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ asr:
fused_batch_size: -1
greedy:
use_cuda_graph_decoder: false # Disabled due to issues with decoding
enable_per_stream_biasing: false # Per-stream biasing in decoder
enable_per_stream_biasing: true # Per-stream biasing in decoder
max_symbols: 10
# n-gram LM
ngram_lm_model: null # Path to the built '.nemo' NGPU-LM model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
get_confidence_utils,
)
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer
Expand Down Expand Up @@ -306,6 +307,36 @@
eos_flags.append(request.is_last)

previous_hypotheses = [state.get_previous_hypothesis() for state in states]

try:
decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer
biasing_enabled = decoding_computer.per_stream_biasing_enabled
except AttributeError:
decoding_computer = None
biasing_enabled = False

if not biasing_enabled and any(state.has_biasing_request() for state in states):
logging.warning("Biasing request is not empty, but decoder does not support per-stream biasing. Skipping")

# Handle per-stream biasing: add biasing models to multi_model if needed
if biasing_enabled:
for i, (request, state, previous_hyp) in enumerate(zip(requests, states, previous_hypotheses)):
if state.has_biasing_request():
if state.options.biasing_cfg.multi_model_id is None:
if state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.add_to_multi_model(
tokenizer=self.asr_model.tokenizer,
biasing_multi_model=decoding_computer.biasing_multi_model,
)
else:
logging.warning(
"Biasing request is not empty, not auto managed and not compiled. Skipping"
)
if previous_hyp is None:
previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(state.options.biasing_cfg)
else:
previous_hyp.biasing_cfg = state.options.biasing_cfg

context, mapping = self.context_manager.get_context(stream_ids)

prompt_vectors = None
Expand Down Expand Up @@ -344,6 +375,16 @@
state.cleanup_after_eou()
ready_state_ids.add(request.stream_id)

# Cleanup per-stream biasing models when stream ends
if biasing_enabled:
for request, state in zip(requests, states):
# only the first request contains biasing options; biasing options for the stream are stored in state
if request.is_last and state.has_biasing_request():
if state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.remove_from_multi_model(
biasing_multi_model=decoding_computer.biasing_multi_model
)

def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None:
"""
Transcribes the feature buffers in a streaming manner.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def set_previous_hypothesis(self, previous_hypothesis: Hypothesis) -> None:
"""
self.previous_hypothesis = previous_hypothesis

def get_previous_hypothesis(self) -> Hypothesis:
def get_previous_hypothesis(self) -> Hypothesis | None:
"""
Get the previous hypothesis
Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo \
enable_itn=False \
enable_nmt=False \
asr_output_granularity=segment

# Cache-Aware RNN-T model
# Runs the streaming inference example script under coverage, appending (-a) to the
# shared /workspace/.coverage data file and measuring only /workspace/nemo sources.
# Uses the cache_aware_rnnt.yaml config with a streaming batch size of 5 and with
# PnC, ITN and NMT disabled; results are written to a JSON file under /tmp.
# NOTE(review): the "boost-gt" input file presumably carries per-stream biasing
# (boosting) requests — confirm against the test data.
coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo \
examples/asr/asr_streaming_inference/asr_streaming_infer.py \
--config-path="../conf/asr_streaming_inference/" \
--config-name=cache_aware_rnnt.yaml \
audio_file="/home/TestData/asr/canary/dev-other-wav-10-boost-gt.json" \
output_filename="/tmp/stt_inference_boost_gt_res_ca_rnnt.json" \
asr.model_name="nvidia/nemotron-speech-streaming-en-0.6b" \
streaming.batch_size=5 \
lang=en \
enable_pnc=False \
enable_itn=False \
enable_nmt=False \
asr_output_granularity=segment
Loading