From 5a732874d5742ddabe899194c7fae4835a839e9c Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 16:31:16 -0500 Subject: [PATCH 1/8] first attempt -- broken --- asr/wenet/bin/recognize_wav.py | 4 + asr/wenet/cli/reverb.py | 13 ++- asr/wenet/transformer/asr_model.py | 95 ++++++---------------- asr/wenet/transformer/search.py | 124 +++++++++++++++++++---------- 4 files changed, 121 insertions(+), 115 deletions(-) diff --git a/asr/wenet/bin/recognize_wav.py b/asr/wenet/bin/recognize_wav.py index d4feafb..3cf112f 100644 --- a/asr/wenet/bin/recognize_wav.py +++ b/asr/wenet/bin/recognize_wav.py @@ -205,4 +205,8 @@ def main(): if __name__ == "__main__": + import time + start_time = time.time() main() + end_time = time.time() + logging.info(f"Total processing time: {end_time - start_time:.2f} seconds") diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index 25f7c2f..978cff1 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -231,14 +231,23 @@ def transcribe_modes( infos={"tasks": ["transcribe"], "langs": ["en"]}, cat_embs=cat_embs, ) + print(get_output( + "txt", + self.tokenizer, + Path(audio_file).stem, + hyps["attention_rescoring"], + timings_adjustment, + chunk_size, + self.input_frame_length, + self.output_frame_length, + )) results.append(hyps) - outputs = [] for mode in modes: outputs.append(get_output( format, self.tokenizer, - Path(audio_file).name, + Path(audio_file).stem, list(chain(*(hyp[mode] for hyp in results))), timings_adjustment, chunk_size, diff --git a/asr/wenet/transformer/asr_model.py b/asr/wenet/transformer/asr_model.py index ba4936d..8ae6c5d 100644 --- a/asr/wenet/transformer/asr_model.py +++ b/asr/wenet/transformer/asr_model.py @@ -874,30 +874,19 @@ def forward_attention_decoder( cat_embs: Optional[torch.Tensor] = None, verbose: bool = False ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ Export interface for c++ call, forward decoder with multiple - hypothesis from ctc prefix beam search and one encoder output - Args: - hyps (torch.Tensor): hyps from ctc prefix beam search, already - pad sos at the begining - hyps_lens (torch.Tensor): length of each hyp in hyps - encoder_out (torch.Tensor): corresponding encoder output - r_hyps (torch.Tensor): hyps from ctc prefix beam search, already - pad eos at the begining which is used fo right to left decoder - reverse_weight: used for verfing whether used right to left decoder, - > 0 will use. - - Returns: - torch.Tensor: decoder output - """ - assert encoder_out.size(0) == 1 - num_hyps = hyps.size(0) - assert hyps_lens.size(0) == num_hyps - encoder_out = encoder_out.repeat(num_hyps, 1, 1) - encoder_mask = torch.ones(num_hyps, - 1, - encoder_out.size(1), - dtype=torch.bool, - device=encoder_out.device) + batch_size = encoder_out.size(0) + beam_size = hyps.size(0) // batch_size + assert hyps.size(0) == batch_size * beam_size, "Number of hypotheses must be batch_size * beam_size" + assert hyps_lens.size(0) == batch_size * beam_size + + # Repeat encoder output for each hypothesis in the beam, maintaining batch separation + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1) + encoder_out = encoder_out.view(batch_size * beam_size, -1, encoder_out.size(-1)) + encoder_mask = torch.ones(batch_size * beam_size, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=encoder_out.device) # input for right to left decoder # this hyps_lens has count token, we need minus it. 
@@ -905,54 +894,21 @@ def forward_attention_decoder( # this hyps has included token, so it should be # convert the original hyps. r_hyps = hyps[:, 1:] - # >>> r_hyps - # >>> tensor([[ 1, 2, 3], - # >>> [ 9, 8, 4], - # >>> [ 2, -1, -1]]) - # >>> r_hyps_lens - # >>> tensor([3, 3, 1]) - - # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used - # in `reverse_pad_list` thus we have to refine the below code. - # Issue: https://github.com/wenet-e2e/wenet/issues/1113 - # Equal to: - # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) - # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) + + # Handle right-to-left decoding max_len = torch.max(r_hyps_lens) index_range = torch.arange(0, max_len, 1).to(encoder_out.device) seq_len_expand = r_hyps_lens.unsqueeze(1) - seq_mask = seq_len_expand > index_range # (beam, max_len) - # >>> seq_mask - # >>> tensor([[ True, True, True], - # >>> [ True, True, True], - # >>> [ True, False, False]]) - index = (seq_len_expand - 1) - index_range # (beam, max_len) - # >>> index - # >>> tensor([[ 2, 1, 0], - # >>> [ 2, 1, 0], - # >>> [ 0, -1, -2]]) + seq_mask = seq_len_expand > index_range # (batch*beam, max_len) + + index = (seq_len_expand - 1) - index_range # (batch*beam, max_len) index = index * seq_mask - # >>> index - # >>> tensor([[2, 1, 0], - # >>> [2, 1, 0], - # >>> [0, 0, 0]]) r_hyps = torch.gather(r_hyps, 1, index) - # >>> r_hyps - # >>> tensor([[3, 2, 1], - # >>> [4, 8, 9], - # >>> [2, 2, 2]]) r_hyps = torch.where(seq_mask, r_hyps, self.eos) - # >>> r_hyps - # >>> tensor([[3, 2, 1], - # >>> [4, 8, 9], - # >>> [2, eos, eos]]) r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) - # >>> r_hyps - # >>> tensor([[sos, 3, 2, 1], - # >>> [sos, 4, 8, 9], - # >>> [sos, 2, eos, eos]]) if self.decoder is not None: + # If using language-specific layers, handle cat_embs if self.lsl_dec: if verbose: print("passing cat_emb to decoder") @@ -960,21 +916,18 @@ def forward_attention_decoder( decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight, - cat_embs) # (num_hyps, max_hyps_len, vocab_size) + cat_embs) # (batch*beam, max_hyps_len, vocab_size) else: decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight, - None) # (num_hyps, max_hyps_len, vocab_size) + None) # (batch*beam, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) - - # right to left decoder may be not used during decoding process, - # which depends on reverse_weight param. 
- # r_dccoder_out will be 0.0, if reverse_weight is 0.0 - r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, - dim=-1) + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) else: decoder_out, r_decoder_out = None, None + return decoder_out, r_decoder_out def onmt_attention_decoding( diff --git a/asr/wenet/transformer/search.py b/asr/wenet/transformer/search.py index 9deb361..8a4b011 100644 --- a/asr/wenet/transformer/search.py +++ b/asr/wenet/transformer/search.py @@ -378,73 +378,113 @@ def attention_rescoring( device = encoder_outs.device assert encoder_outs.shape[0] == len(ctc_prefix_results) batch_size = encoder_outs.shape[0] - results = [] + + # Collect all hypotheses and their lengths + all_hyps = [] + all_ctc_scores = [] + beam_sizes = [] + for b in range(batch_size): + all_hyps.extend(ctc_prefix_results[b].nbest) + all_ctc_scores.extend(ctc_prefix_results[b].nbest_scores) + beam_sizes.append(len(ctc_prefix_results[b].nbest)) + + # Pad all hypotheses together + hyps_pad = pad_sequence([torch.tensor(hyp, device=device, dtype=torch.long) + for hyp in all_hyps], True, model.ignore_id) + hyps_lens = torch.tensor([len(hyp) for hyp in all_hyps], + device=device, dtype=torch.long) + + # Handle special tokens if needed + if getattr(model, 'special_tokens', None) is not None \ + and "transcribe" in model.special_tokens: + prev_len = hyps_pad.size(1) + # Repeat tasks and langs for each beam + tasks = [infos["tasks"][b] for b in range(batch_size) for _ in range(beam_sizes[b])] + langs = [infos["langs"][b] for b in range(batch_size) for _ in range(beam_sizes[b])] + hyps_pad, _ = add_whisper_tokens( + model.special_tokens, + hyps_pad, + model.ignore_id, + tasks=tasks, + no_timestamp=True, + langs=langs, + use_prev=False) + cur_len = hyps_pad.size(1) + hyps_lens = hyps_lens + cur_len - prev_len + prefix_len = 4 + else: + hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at beginning + prefix_len = 1 + + # Repeat encoder outputs for each beam + encoder_out_lens = [] + encoder_outs_expanded = [] for b in range(batch_size): + beam_size = beam_sizes[b] encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0) + encoder_outs_expanded.append(encoder_out.repeat(beam_size, 1, 1)) + encoder_out_lens.extend([encoder_lens[b]] * beam_size) + + encoder_outs_expanded = torch.cat(encoder_outs_expanded, dim=0) + + # Forward decoder with all hypotheses at once + decoder_out, r_decoder_out = model.forward_attention_decoder( + hyps_pad, hyps_lens, encoder_outs_expanded, reverse_weight, cat_embs) + + # Process results batch by batch + results = [] + offset = 0 + for b in range(batch_size): + beam_size = beam_sizes[b] hyps = ctc_prefix_results[b].nbest - ctc_scores = ctc_prefix_results[b].nbest_scores - hyps_pad = pad_sequence([ - torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps - ], True, model.ignore_id) # (beam_size, max_hyps_len) - hyps_lens = torch.tensor([len(hyp) for hyp in hyps], - device=device, - dtype=torch.long) # (beam_size,) - if getattr(model, 'special_tokens', None) is not None \ - and "transcribe" in model.special_tokens: - prev_len = hyps_pad.size(1) - hyps_pad, _ = add_whisper_tokens( - model.special_tokens, - hyps_pad, - model.ignore_id, - tasks=[infos["tasks"][b]] * len(hyps), - no_timestamp=True, - langs=[infos["langs"][b]] * len(hyps), - use_prev=False) - cur_len = hyps_pad.size(1) - hyps_lens = hyps_lens + cur_len - prev_len - prefix_len = 4 - else: - hyps_pad, _ = add_sos_eos(hyps_pad, sos, 
eos, model.ignore_id) - hyps_lens = hyps_lens + 1 # Add at begining - prefix_len = 1 - decoder_out, r_decoder_out = model.forward_attention_decoder( - hyps_pad, hyps_lens, encoder_out, reverse_weight, cat_embs) - # Only use decoder score for rescoring best_score = -float('inf') best_index = 0 confidences = [] tokens_confidences = [] - for i, hyp in enumerate(hyps): + + # Process each hypothesis in the current batch + for i in range(beam_size): + idx = offset + i + hyp = hyps[i] score = 0.0 tc = [] # tokens confidences + + # Calculate forward decoder score for j, w in enumerate(hyp): - s = decoder_out[i][j + (prefix_len - 1)][w] + s = decoder_out[idx][j + (prefix_len - 1)][w] score += s tc.append(math.exp(s)) - score += decoder_out[i][len(hyp) + (prefix_len - 1)][eos] - # add right to left decoder score + score += decoder_out[idx][len(hyp) + (prefix_len - 1)][eos] + + # Add right to left decoder score if needed if reverse_weight > 0 and r_decoder_out.dim() > 0: r_score = 0.0 for j, w in enumerate(hyp): - s = r_decoder_out[i][len(hyp) - j - 1 + - (prefix_len - 1)][w] + s = r_decoder_out[idx][len(hyp) - j - 1 + (prefix_len - 1)][w] r_score += s tc[j] = (tc[j] + math.exp(s)) / 2 - r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos] + r_score += r_decoder_out[idx][len(hyp) + (prefix_len - 1)][eos] score = score * (1 - reverse_weight) + r_score * reverse_weight + confidences.append(math.exp(score / (len(hyp) + 1))) - # add ctc score - score += ctc_scores[i] * ctc_weight + score += all_ctc_scores[idx] * ctc_weight + if score > best_score: best_score = score best_index = i tokens_confidences.append(tc) + + # Add best result for current batch results.append( DecodeResult(hyps[best_index], - best_score, - confidence=confidences[best_index], - times=ctc_prefix_results[b].nbest_times[best_index], - tokens_confidence=tokens_confidences[best_index])) + best_score, + confidence=confidences[best_index], + times=ctc_prefix_results[b].nbest_times[best_index], + tokens_confidence=tokens_confidences[best_index])) + + offset += beam_size + return results def joint_decoding( From 4e22cc81561043c47e65762428a3cba338b7a5fd Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 17:20:15 -0500 Subject: [PATCH 2/8] Removing redundant code --- asr/wenet/transformer/asr_model.py | 10 +--------- asr/wenet/transformer/search.py | 4 ++-- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/asr/wenet/transformer/asr_model.py b/asr/wenet/transformer/asr_model.py index 8ae6c5d..8a40faf 100644 --- a/asr/wenet/transformer/asr_model.py +++ b/asr/wenet/transformer/asr_model.py @@ -874,15 +874,7 @@ def forward_attention_decoder( cat_embs: Optional[torch.Tensor] = None, verbose: bool = False ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - batch_size = encoder_out.size(0) - beam_size = hyps.size(0) // batch_size - assert hyps.size(0) == batch_size * beam_size, "Number of hypotheses must be batch_size * beam_size" - assert hyps_lens.size(0) == batch_size * beam_size - - # Repeat encoder output for each hypothesis in the beam, maintaining batch separation - encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1) - encoder_out = encoder_out.view(batch_size * beam_size, -1, encoder_out.size(-1)) - encoder_mask = torch.ones(batch_size * beam_size, + encoder_mask = torch.ones(encoder_out.size(0), 1, encoder_out.size(1), dtype=torch.bool, diff --git a/asr/wenet/transformer/search.py b/asr/wenet/transformer/search.py index 8a4b011..b07ad85 100644 --- a/asr/wenet/transformer/search.py 
+++ b/asr/wenet/transformer/search.py @@ -423,9 +423,9 @@ def attention_rescoring( for b in range(batch_size): beam_size = beam_sizes[b] encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0) - encoder_outs_expanded.append(encoder_out.repeat(beam_size, 1, 1)) + encoder_outs_expanded.append(encoder_out.expand(beam_size, -1, -1)) encoder_out_lens.extend([encoder_lens[b]] * beam_size) - + encoder_outs_expanded = torch.cat(encoder_outs_expanded, dim=0) # Forward decoder with all hypotheses at once From df24b5d18278e58014e23d7ed650d3ccd36c90e0 Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 17:38:50 -0500 Subject: [PATCH 3/8] Cleanup and comments --- asr/wenet/bin/recognize_wav.py | 4 ---- asr/wenet/cli/reverb.py | 10 ---------- asr/wenet/transformer/asr_model.py | 30 ++++++++++++++++++++++++++++++ asr/wenet/transformer/search.py | 11 +++++++++++ 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/asr/wenet/bin/recognize_wav.py b/asr/wenet/bin/recognize_wav.py index 3cf112f..d4feafb 100644 --- a/asr/wenet/bin/recognize_wav.py +++ b/asr/wenet/bin/recognize_wav.py @@ -205,8 +205,4 @@ def main(): if __name__ == "__main__": - import time - start_time = time.time() main() - end_time = time.time() - logging.info(f"Total processing time: {end_time - start_time:.2f} seconds") diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index 978cff1..ab7af20 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -231,16 +231,6 @@ def transcribe_modes( infos={"tasks": ["transcribe"], "langs": ["en"]}, cat_embs=cat_embs, ) - print(get_output( - "txt", - self.tokenizer, - Path(audio_file).stem, - hyps["attention_rescoring"], - timings_adjustment, - chunk_size, - self.input_frame_length, - self.output_frame_length, - )) results.append(hyps) outputs = [] for mode in modes: diff --git a/asr/wenet/transformer/asr_model.py b/asr/wenet/transformer/asr_model.py index 8a40faf..5d744f4 100644 --- a/asr/wenet/transformer/asr_model.py +++ b/asr/wenet/transformer/asr_model.py @@ -874,6 +874,36 @@ def forward_attention_decoder( cat_embs: Optional[torch.Tensor] = None, verbose: bool = False ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (torch.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining + (batch*beam, max_hyps_len) + hyps_lens (torch.Tensor): length of each hyp in hyps + (batch*beam) + encoder_out (torch.Tensor): corresponding encoder output + r_hyps (torch.Tensor): hyps from ctc prefix beam search, already + pad eos at the begining which is used fo right to left decoder + reverse_weight: used for verfing whether used right to left decoder, + > 0 will use. 
+ cat_embs (torch.Tensor): category embeddings + (1, cat_emb_dim) + Returns: + decoder_out (torch.Tensor): decoder output + (batch*beam, max_hyps_len, vocab_size) + r_decoder_out (torch.Tensor): decoder output for right to left decoder + (batch*beam, max_hyps_len, vocab_size) + """ + + batch_size = encoder_out.size(0) + beam_size = hyps.size(0) // batch_size + assert hyps.size(0) == batch_size * beam_size, "Number of hypotheses must be batch_size * beam_size" + assert hyps_lens.size(0) == batch_size * beam_size + + # Repeat encoder output for each hypothesis in the beam, maintaining batch separation + encoder_out = encoder_out.unsqueeze(1).expand(-1, beam_size, -1, -1) + encoder_out = encoder_out.view(batch_size * beam_size, -1, encoder_out.size(-1)) encoder_mask = torch.ones(encoder_out.size(0), 1, encoder_out.size(1), diff --git a/asr/wenet/transformer/search.py b/asr/wenet/transformer/search.py index b07ad85..01ee46d 100644 --- a/asr/wenet/transformer/search.py +++ b/asr/wenet/transformer/search.py @@ -389,8 +389,10 @@ def attention_rescoring( beam_sizes.append(len(ctc_prefix_results[b].nbest)) # Pad all hypotheses together + # hyps_pad: (batch*beam, max_hyps_len) hyps_pad = pad_sequence([torch.tensor(hyp, device=device, dtype=torch.long) for hyp in all_hyps], True, model.ignore_id) + # hyps_lens: (batch*beam) hyps_lens = torch.tensor([len(hyp) for hyp in all_hyps], device=device, dtype=torch.long) @@ -422,13 +424,22 @@ def attention_rescoring( encoder_outs_expanded = [] for b in range(batch_size): beam_size = beam_sizes[b] + # encoder_out: (1, max_len, encoder_dim) encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0) + # encoder_outs_expanded: (beam_size, max_len, encoder_dim) encoder_outs_expanded.append(encoder_out.expand(beam_size, -1, -1)) + # encoder_out_lens: (beam_size) encoder_out_lens.extend([encoder_lens[b]] * beam_size) + # encoder_outs_expanded: (batch*beam, max_len, encoder_dim) encoder_outs_expanded = torch.cat(encoder_outs_expanded, dim=0) # Forward decoder with all hypotheses at once + # MDR: forward_attention_decoder will rexpand the encoder_outs_expanded + # to (batch*beam, max_len, encoder_dim). Necessary to maintain C++ + # compatibility. 
+ # decoder_out: (batch*beam, max_hyps_len, vocab_size) + # r_decoder_out: (batch*beam, max_hyps_len, vocab_size) decoder_out, r_decoder_out = model.forward_attention_decoder( hyps_pad, hyps_lens, encoder_outs_expanded, reverse_weight, cat_embs) From 041476b602c5b5ff09f9ee1a9af8357c041bf355 Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 17:39:17 -0500 Subject: [PATCH 4/8] return to name --- asr/wenet/cli/reverb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index ab7af20..13ecd91 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -237,7 +237,7 @@ def transcribe_modes( outputs.append(get_output( format, self.tokenizer, - Path(audio_file).stem, + Path(audio_file).name, list(chain(*(hyp[mode] for hyp in results))), timings_adjustment, chunk_size, From 030a18b458d5766e457c72d15549565128d8ab7a Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 17:39:41 -0500 Subject: [PATCH 5/8] extra spacing --- asr/wenet/cli/reverb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index 13ecd91..25f7c2f 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -232,6 +232,7 @@ def transcribe_modes( cat_embs=cat_embs, ) results.append(hyps) + outputs = [] for mode in modes: outputs.append(get_output( From 6eed15ff853b12950683575df714fb3449603711 Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 18:22:06 -0500 Subject: [PATCH 6/8] Deal with the edge case of the last feats being shorter than batch --- asr/wenet/cli/reverb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index 25f7c2f..2e140b1 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -162,7 +162,8 @@ def feats_batcher( # Apply padding if needed pad_amt = last_batch_num_feats - feats_batch.shape[1] if pad_amt > 0: - feats_lengths[-1] -= pad_amt + if last_batch_size == 1: + feats_lengths[-1] -= pad_amt feats_batch = F.pad( input=feats_batch, pad=(0, 0, 0, pad_amt, 0, 0), From c764dbb4f2a10561aa01c733e22394491bf038ff Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 22:11:12 -0500 Subject: [PATCH 7/8] Handling GPU when loading from cache --- asr/wenet/bin/recognize_wav.py | 2 +- asr/wenet/cli/reverb.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/asr/wenet/bin/recognize_wav.py b/asr/wenet/bin/recognize_wav.py index d4feafb..ff3f450 100644 --- a/asr/wenet/bin/recognize_wav.py +++ b/asr/wenet/bin/recognize_wav.py @@ -162,7 +162,7 @@ def main(): raise RuntimeError("One of either --model or (--checkpoint and --config) must be set.") if model_arg_set: - reverb_model = load_model(args.model) + reverb_model = load_model(args.model, args.gpu) else: reverb_model = ReverbASR( args.config, diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index 2e140b1..3dc2673 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -323,7 +323,8 @@ def get_output( def load_model( - model: str + model: str, + gpu: int = -1, ): """Loads a reverb model. If "model" points to a path that exists, tries to load a model using those files at "model". 
@@ -354,7 +355,8 @@ def load_model( logging.info(f"Loading the model with {config_path = } and {checkpoint_path = }") return ReverbASR( str(config_path), - str(checkpoint_path) + str(checkpoint_path), + gpu = gpu, ) From 3bb3d999a999dc1ad9009122c8b2ebd3f10bc0f5 Mon Sep 17 00:00:00 2001 From: Miguel del Rio Date: Wed, 5 Mar 2025 22:17:55 -0500 Subject: [PATCH 8/8] Moving things to GPU sooner --- asr/wenet/cli/reverb.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/asr/wenet/cli/reverb.py b/asr/wenet/cli/reverb.py index 3dc2673..2eab4eb 100644 --- a/asr/wenet/cli/reverb.py +++ b/asr/wenet/cli/reverb.py @@ -121,12 +121,11 @@ def compute_feats( ) -> torch.Tensor: waveform, sample_rate = torchaudio.load(audio_file, normalize=False) logging.info(f"detected sample rate: {sample_rate}") - waveform = waveform.to(torch.float) + waveform = waveform.to(torch.float).to(self.device) if sample_rate != resample_rate: waveform = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=resample_rate - )(waveform) - waveform = waveform.to(self.device) + ).to(self.device)(waveform) feats = kaldi.fbank( waveform, num_mel_bins=num_mel_bins, @@ -151,14 +150,12 @@ def feats_batcher( feats_batch = infeats[ :, b * batch_num_feats : b * batch_num_feats + batch_num_feats, : ] - feats_lengths = torch.tensor([chunk_size] * batch_size, dtype=torch.int32) + feats_lengths = torch.tensor([chunk_size] * batch_size, dtype=torch.int32, device=self.device) if b == num_batches - 1: # last batch can be smaller than batch size last_batch_size = ceil(feats_batch.shape[1] / chunk_size) last_batch_num_feats = chunk_size * last_batch_size - feats_lengths = torch.tensor( - [chunk_size] * last_batch_size, dtype=torch.int32 - ) + feats_lengths = torch.tensor([chunk_size] * last_batch_size, dtype=torch.int32, device=self.device) # Apply padding if needed pad_amt = last_batch_num_feats - feats_batch.shape[1] if pad_amt > 0: @@ -172,7 +169,7 @@ def feats_batcher( ) yield feats_batch.reshape( -1, chunk_size, self.test_conf["fbank_conf"]["num_mel_bins"] - ), feats_lengths.to(self.device) + ), feats_lengths def transcribe_modes( self,
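
Editor's note: the core change across PATCH 1/8-3/8 is that attention rescoring now expands each utterance's encoder output once per n-best hypothesis and runs a single decoder forward pass over all batch*beam hypotheses, instead of looping one utterance at a time. The following standalone Python sketch illustrates only that expansion and row layout; it is not part of the patches, and all tensor names and sizes are illustrative assumptions.

    # Minimal sketch (assumed shapes, not from the patches): expand encoder
    # outputs from (batch, T, D) to (batch * beam, T, D) so every CTC n-best
    # hypothesis in the batch can be rescored in one decoder call.
    import torch

    batch_size, beam_size, max_len, dim = 2, 3, 5, 4
    encoder_out = torch.randn(batch_size, max_len, dim)

    # expand() is a zero-copy view along the beam dimension (the repeat() ->
    # expand() change made in PATCH 2/8); reshape() then materialises the
    # (batch * beam, T, D) layout the decoder consumes.
    expanded = encoder_out.unsqueeze(1).expand(-1, beam_size, -1, -1)
    expanded = expanded.reshape(batch_size * beam_size, max_len, dim)

    # Row b * beam_size + i holds the encoder output for utterance b,
    # hypothesis i, so per-utterance results can be recovered with a running
    # offset, as the rewritten attention_rescoring() does.
    assert expanded.shape == (batch_size * beam_size, max_len, dim)
    assert torch.equal(expanded[1 * beam_size + 0], encoder_out[1])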