diff --git a/egs/emilia/CLAP/clsp/clap_datamodule.py b/egs/emilia/CLAP/clsp/clap_datamodule.py new file mode 100644 index 0000000000..8d0007f562 --- /dev/null +++ b/egs/emilia/CLAP/clsp/clap_datamodule.py @@ -0,0 +1,328 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from lhotse import CutSet, combine, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + DynamicBucketingSampler, + SimpleCutSampler, + UnsupervisedWaveformDataset, +) +from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class DataModule: + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="CLAP data related options", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=16, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: int = 1, + rank: int = 0, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = UnsupervisedWaveformDataset() + + if self.args.bucketing_sampler: + logging.info( + "Using DynamicBucketingSampler with strict FixedBucketBatchSizeConstraint." + ) + constraint = FixedBucketBatchSizeConstraint( + max_seq_len_buckets=self.args.max_seq_len_buckets, + batch_sizes=self.args.fixed_batch_sizes, + ) + train_sampler = DynamicBucketingSampler( + cuts_train, + constraint=constraint, + shuffle=True, + drop_last=True, + duration_bins=self.args.duration_bins, + buffer_size=self.args.num_buckets * 5000, + sync_buckets=True, + concurrent=False, + world_size=world_size, + rank=rank, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=4, + persistent_workers=True, + pin_memory=True, + prefetch_factor=16, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = 1, + rank: int = 0, + ) -> DataLoader: + logging.info("About to create dev dataset") + validate = UnsupervisedWaveformDataset() + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=4, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = UnsupervisedWaveformDataset() + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=4, + ) + return test_dl + + def estimate_duration_bins( + self, + cuts: CutSet, + world_size: int = 1, + rank: int = 0, + ) -> List[float]: + logging.info("Estimating duration bins for FixedBucketBatchSizeConstraint") + + dummy_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=True, + drop_last=True, + buffer_size=self.args.num_buckets * 5000, + sync_buckets=True, + concurrent=False, + world_size=world_size, + rank=rank, + ) + duration_bins = dummy_sampler.duration_bins + del dummy_sampler + return duration_bins + + @lru_cache() + def emilia_en_cuts(self) -> CutSet: + logging.info("About to get Emilia EN tars") + filenames = glob.glob("./download/Emilia/EN/*.tar") + logging.info(f"Loading Emilia {len(filenames)} tars in lazy mode") + return CutSet.from_webdataset( + filenames, + shuffle_shards=True, + split_by_worker=False, + split_by_node=False, + ) + + @lru_cache() + def paraspeechcaps_train_base_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps train-base shuffled cuts") + return load_manifest_lazy( + self.args.manifest_dir + / "paraspeechcaps_cuts_train_base_shuf-selected.jsonl.gz" + ) + + @lru_cache() + def paraspeechcaps_dev_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps dev cuts") + splits = ["voxceleb", "expresso", "ears"] + return combine( + load_manifest_lazy( + self.args.manifest_dir + / f"paraspeechcaps_cuts_dev-{s}-selected.jsonl.gz" + ) + for s in splits + ) + + @lru_cache() + def paraspeechcaps_test_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps test cuts") + splits = ["voxceleb", "expresso", "ears"] + return combine( + load_manifest_lazy( + self.args.manifest_dir + / f"paraspeechcaps_cuts_test-{s}-selected.jsonl.gz" + ) + for s in splits + ) + + @lru_cache() + def iemocap_cuts(self) -> CutSet: + logging.info("About to get iemocap cuts") + return load_manifest_lazy(self.args.manifest_dir / "iemocap_cuts_all.jsonl.gz") + + @lru_cache() + def ravdess_cuts(self) -> CutSet: + logging.info("About to get ravdess cuts") + return load_manifest_lazy(self.args.manifest_dir / "ravdess_cuts_all.jsonl.gz") + + @lru_cache() + def cremad_cuts(self) -> CutSet: + logging.info("About to get crema-d cuts") + return load_manifest_lazy(self.args.manifest_dir / "cremad_cuts_test.jsonl.gz") diff --git a/egs/emilia/CLAP/clsp/eval.sh b/egs/emilia/CLAP/clsp/eval.sh new file mode 100755 index 0000000000..bb7fab71c1 --- /dev/null +++ b/egs/emilia/CLAP/clsp/eval.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +export CUDA_VISIBLE_DEVICES=0 + +md=800 + +python clsp/eval_zero_shot_classification.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --max-duration $md diff --git a/egs/emilia/CLAP/clsp/eval_speech_text_retrieval.py b/egs/emilia/CLAP/clsp/eval_speech_text_retrieval.py new file mode 100755 index 0000000000..7aa97905b1 --- /dev/null +++ b/egs/emilia/CLAP/clsp/eval_speech_text_retrieval.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +from typing import Dict + +import torch +import torch.nn as nn +from clap_datamodule import DataModule +from transformers import AutoModel + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + return parser + + +def evaluate( + model: nn.Module, + test_dl: torch.utils.data.DataLoader, + caption_type: str, +) -> Dict[str, float]: + model.eval() + device = next(model.parameters()).device + + metrics = {} + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "clip_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + + with torch.no_grad(): + for _, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + audio_lens = batch["audio_lens"].to(device) + + if caption_type == "short_captions": + captions = [ + c.supervisions[0].custom[caption_type][0] for c in batch["cuts"] + ] + elif caption_type == "long_captions": + captions = [ + c.supervisions[0].custom[caption_type][-1] for c in batch["cuts"] + ] + else: + raise ValueError + + audio_features, text_features, _ = model( + text=captions, + audio=audio, + audio_lens=audio_lens, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + metrics_single_dataset = compute_metrics( + audio_features=torch.cat(eval_info["all_audio_features"]), + text_features=torch.cat(eval_info["all_text_features"]), + ) + metrics.update(metrics_single_dataset) + + result_dict = {"metrics": metrics} + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + logits_per_audio = audio_features @ text_features.t() + logits_per_text = logits_per_audio.t() + + metrics = {} + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + # details[f"{name}_ranks"] = ranking.detach().cpu().tolist() + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + model = AutoModel.from_pretrained( + "yfyeung/CLSP", + trust_remote_code=True, + ) + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + paraspeechcaps_test_cuts = datamodule.paraspeechcaps_test_cuts() + paraspeechcaps_test_dl = datamodule.test_dataloaders(paraspeechcaps_test_cuts) + + test_sets = [ + "paraspeechcaps_test", + ] + test_dls = [ + paraspeechcaps_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + model=model, + test_dl=test_dl, + caption_type="long_captions", + ) + metrics = result_dict["metrics"] + print(f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])) + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/clsp/eval_zero_shot_classification.py b/egs/emilia/CLAP/clsp/eval_zero_shot_classification.py new file mode 100755 index 0000000000..8060e4692d --- /dev/null +++ b/egs/emilia/CLAP/clsp/eval_zero_shot_classification.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +from typing import Dict + +import torch +import torch.nn as nn +from clap_datamodule import DataModule +from transformers import AutoModel + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + return parser + + +def map_iemocap_emotion_label_to_index(label: str) -> int: + label_map = { + "hap": 0, + "exc": 1, + "ang": 2, + "sad": 3, + "neu": 4, + } + return label_map[label] + + +def map_ravdess_emotion_label_to_index(label: str) -> int: + label_map = { + "angry": 0, + "calm": 1, + "disgust": 2, + "fearful": 3, + "happy": 4, + "sad": 5, + "surprised": 6, + "neutral": 7, + } + return label_map[label] + + +def map_ravdess_gender_label_to_index(label: str) -> int: + label_map = { + "male": 0, + "female": 1, + } + return label_map[label] + + +def map_cremad_emotion_label_to_index(label: str) -> int: + label_map = { + "H": 0, + "S": 1, + "A": 2, + "F": 3, + "D": 4, + "N": 5, + } + return label_map[label] + + +def map_cremad_age_label_to_index(label: str) -> int: + if label < 20: + index = 0 + elif label < 40: + index = 1 + elif label < 60: + index = 2 + else: + index = 3 + return index + + +def generate_iemocap_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a excited tone.", + "A speaker in a angry tone.", + "A speaker in a sad tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_emotion_prompts() -> str: + return [ + "A speaker in a angry tone.", + "A speaker in a calm tone.", + "A speaker in a disgust tone.", + "A speaker in a fear tone.", + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a surprised tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_gender_prompts() -> str: + return [ + "A male speaker.", + "A female speaker.", + ] + + +def generate_cremad_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a angry tone.", + "A speaker in a fear tone.", + "A speaker in a disgust tone.", + "A speaker in a neutral tone.", + ] + + +def generate_cremad_age_prompts() -> str: + return [ + "A child or young teenager speaker.", + "An adult speaker.", + "A middle-aged speaker.", + "An older or elder speaker.", + ] + + +def evaluate( + model: nn.Module, + test_set: str, + test_dl: torch.utils.data.DataLoader, +) -> Dict[str, float]: + model.eval() + device = next(model.parameters()).device + + metrics = {} + eval_info = { + "all_audio_features": [], + "all_gt_labels": [], + } + + if test_set == "iemocap_emotion": + prompts = generate_iemocap_emotion_prompts() + elif test_set == "ravdess_emotion": + prompts = generate_ravdess_emotion_prompts() + elif test_set == "ravdess_gender": + prompts = generate_ravdess_gender_prompts() + elif test_set == "cremad_emotion": + prompts = generate_cremad_emotion_prompts() + elif test_set == "cremad_age": + prompts = generate_cremad_age_prompts() + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + _, text_features, _ = model( + text=prompts, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + with torch.no_grad(): + for _, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + audio_lens = batch["audio_lens"].to(device) + + if test_set == "iemocap_emotion": + gt_labels = [ + map_iemocap_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_emotion": + gt_labels = [ + map_ravdess_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_gender": + gt_labels = [ + map_ravdess_gender_label_to_index(c.supervisions[0].gender) + for c in batch["cuts"] + ] + elif test_set == "cremad_emotion": + gt_labels = [ + map_cremad_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "cremad_age": + gt_labels = [ + map_cremad_age_label_to_index(c.supervisions[0].age) + for c in batch["cuts"] + ] + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + audio_features, _, _ = model( + audio=audio, + audio_lens=audio_lens, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_gt_labels"].extend(gt_labels) + + metrics_single_dataset = compute_metrics( + audio_features=torch.cat(eval_info["all_audio_features"]), + text_features=text_features.cpu(), + gt_labels=torch.tensor(eval_info["all_gt_labels"], dtype=torch.int64), + test_set=test_set, + ) + metrics.update(metrics_single_dataset) + + result_dict = {"metrics": metrics} + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + gt_labels: torch.Tensor, + test_set: str, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + + logits_per_audio = torch.matmul(audio_features, text_features.t()) + preds = logits_per_audio.argmax(dim=1) + + if test_set == "iemocap_emotion": + gt_labels = gt_labels.clamp(min=1) + preds = preds.clamp(min=1) + + wa = (preds == gt_labels).float().mean().item() + + recall_sum = 0.0 + num_classes = 0 + for cls_idx in torch.unique(gt_labels): + cls_idx = cls_idx.item() + cls_mask = gt_labels == cls_idx + recall = (preds[cls_mask] == cls_idx).float().mean().item() + recall_sum += recall + num_classes += 1 + print(f"{test_set}: cls {cls_idx}, recall {recall}") + uar = recall_sum / num_classes if num_classes > 0 else 0.0 + + return {"wa": wa, "uar": uar} + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + model = AutoModel.from_pretrained( + "yfyeung/CLSP", + trust_remote_code=True, + ) + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + iemocap_test_cuts = datamodule.iemocap_cuts() + iemocap_test_dl = datamodule.test_dataloaders(iemocap_test_cuts) + + ravdess_test_cuts = datamodule.ravdess_cuts() + ravdess_test_dl = datamodule.test_dataloaders(ravdess_test_cuts) + + cremad_test_cuts = datamodule.cremad_cuts() + cremad_test_dl = datamodule.test_dataloaders(cremad_test_cuts) + + test_sets = [ + "iemocap_emotion", + "ravdess_emotion", + "cremad_emotion", + "ravdess_gender", + "cremad_age", + ] + test_dls = [ + iemocap_test_dl, + ravdess_test_dl, + cremad_test_dl, + ravdess_test_dl, + cremad_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + model=model, + test_set=test_set, + test_dl=test_dl, + ) + metrics = result_dict["metrics"] + print(f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])) + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/glap/clap_datamodule.py b/egs/emilia/CLAP/glap/clap_datamodule.py new file mode 100644 index 0000000000..8229cc26e7 --- /dev/null +++ b/egs/emilia/CLAP/glap/clap_datamodule.py @@ -0,0 +1,319 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from lhotse import CutSet, combine, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + DynamicBucketingSampler, + SimpleCutSampler, + UnsupervisedWaveformDataset, +) +from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class DataModule: + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="CLAP data related options", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=16, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: int = 1, + rank: int = 0, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = UnsupervisedWaveformDataset() + + if self.args.bucketing_sampler: + logging.info( + "Using DynamicBucketingSampler with strict FixedBucketBatchSizeConstraint." + ) + constraint = FixedBucketBatchSizeConstraint( + max_seq_len_buckets=self.args.max_seq_len_buckets, + batch_sizes=self.args.fixed_batch_sizes, + ) + train_sampler = DynamicBucketingSampler( + cuts_train, + constraint=constraint, + shuffle=True, + drop_last=True, + duration_bins=self.args.duration_bins, + buffer_size=self.args.num_buckets * 5000, + sync_buckets=True, + concurrent=False, + world_size=world_size, + rank=rank, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=4, + persistent_workers=True, + pin_memory=True, + prefetch_factor=16, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = 1, + rank: int = 0, + ) -> DataLoader: + logging.info("About to create dev dataset") + validate = UnsupervisedWaveformDataset() + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=4, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = UnsupervisedWaveformDataset() + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=4, + ) + return test_dl + + def estimate_duration_bins( + self, + cuts: CutSet, + world_size: int = 1, + rank: int = 0, + ) -> List[float]: + logging.info("Estimating duration bins for FixedBucketBatchSizeConstraint") + + dummy_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=True, + drop_last=True, + buffer_size=self.args.num_buckets * 5000, + sync_buckets=True, + concurrent=False, + world_size=world_size, + rank=rank, + ) + duration_bins = dummy_sampler.duration_bins + del dummy_sampler + return duration_bins + + @lru_cache() + def emilia_en_cuts(self) -> CutSet: + logging.info("About to get Emilia EN tars") + filenames = glob.glob("./download/Emilia/EN/*.tar") + logging.info(f"Loading Emilia {len(filenames)} tars in lazy mode") + return CutSet.from_webdataset( + filenames, + shuffle_shards=True, + split_by_worker=False, + split_by_node=False, + ) + + @lru_cache() + def paraspeechcaps_train_base_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps train-base shuffled cuts") + return load_manifest_lazy( + self.args.manifest_dir + / "paraspeechcaps_cuts_train_base_shuf-selected.jsonl.gz" + ) + + @lru_cache() + def paraspeechcaps_dev_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps dev cuts") + splits = ["voxceleb", "expresso", "ears"] + return combine( + load_manifest_lazy( + self.args.manifest_dir + / f"paraspeechcaps_cuts_dev-{s}-selected.jsonl.gz" + ) + for s in splits + ) + + @lru_cache() + def paraspeechcaps_test_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps test cuts") + splits = ["voxceleb", "expresso", "ears"] + return combine( + load_manifest_lazy( + self.args.manifest_dir + / f"paraspeechcaps_cuts_test-{s}-selected.jsonl.gz" + ) + for s in splits + ) + + @lru_cache() + def iemocap_cuts(self) -> CutSet: + logging.info("About to get iemocap cuts") + return load_manifest_lazy(self.args.manifest_dir / "iemocap_cuts_all.jsonl.gz") + + @lru_cache() + def ravdess_cuts(self) -> CutSet: + logging.info("About to get ravdess cuts") + return load_manifest_lazy(self.args.manifest_dir / "ravdess_cuts_all.jsonl.gz") + + @lru_cache() + def cremad_cuts(self) -> CutSet: + logging.info("About to get crema-d cuts") + return load_manifest_lazy(self.args.manifest_dir / "cremad_cuts_test.jsonl.gz") diff --git a/egs/emilia/CLAP/glap/evaluate.sh b/egs/emilia/CLAP/glap/evaluate.sh new file mode 100755 index 0000000000..9b3e033945 --- /dev/null +++ b/egs/emilia/CLAP/glap/evaluate.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +export PYTHONPATH=/root/icefall:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=$1 + +md=800 + +exp_dir=glap/exp + +echo $exp_dir + +if true; then +python glap/evaluate_retrieval.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +fi + +if false; then +python glap/evaluate_zero_shot_classification.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +fi + +# python /root/busygpu/run.py & diff --git a/egs/emilia/CLAP/glap/evaluate_retrieval.py b/egs/emilia/CLAP/glap/evaluate_retrieval.py new file mode 100755 index 0000000000..3d2e1877af --- /dev/null +++ b/egs/emilia/CLAP/glap/evaluate_retrieval.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +import torch +from clap_datamodule import DataModule +from glap_model import glap_inference + +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + + return params + + +def evaluate( + params: AttributeDict, + model: Any, + device: torch.device, + test_dl: torch.utils.data.DataLoader, + caption_type: str, + return_details: bool = False, +) -> Dict[str, float]: + """Run the Speech-Text Retrieval evaluation process.""" + metrics = {} + num_samples = 0 + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + eval_detail = { + "all_audio_paths": [], + "all_texts": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + audio_lens = batch["audio_lens"].to(device) + + if caption_type == "short_captions": + captions = [c.supervisions[0].short_captions[0] for c in batch["cuts"]] + elif caption_type == "long_captions": + captions = [c.supervisions[0].long_captions[-1] for c in batch["cuts"]] + else: + raise ValueError + + audio_features = model.encode_audio(audio, audio_lens) + text_features = model.encode_text(captions, device=device) + + num_samples += audio_features.shape[0] + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + if return_details: + eval_detail["all_audio_paths"].extend( + [c.recording.sources[0].source for c in batch["cuts"]] + ) + eval_detail["all_texts"].extend(captions) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + all_audio_features = torch.cat(eval_info["all_audio_features"]) + all_text_features = torch.cat(eval_info["all_text_features"]) + metrics_single_dataset, details_single_dataset = compute_metrics( + audio_features=all_audio_features, + text_features=all_text_features, + ) + metrics.update(metrics_single_dataset) + + if return_details: + details = {} + for k, ranks in details_single_dataset.items(): + if k == "audio_to_text_ranks": + src_list = eval_detail["all_audio_paths"] + tgt_list = eval_detail["all_texts"] + elif k == "text_to_audio_ranks": + src_list = eval_detail["all_texts"] + tgt_list = eval_detail["all_audio_paths"] + else: + raise ValueError + + details[k] = { + src_list[i]: [ + f"GT# {tgt_list[j]}" if j == i else tgt_list[j] for j in ranking + ] + for i, ranking in enumerate(ranks) + } + + result_dict = {"metrics": metrics} + if return_details: + result_dict["details"] = details + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + logits_per_audio = 100 * (audio_features @ text_features.t()) + logits_per_text = logits_per_audio.t() + + metrics = {} + metrics["num_samples"] = N + + details = {} + + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + # preds = torch.where(ranking == ground_truth)[1] + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + details[f"{name}_ranks"] = ranking.detach().cpu().tolist() + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics, details + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "speech-text-retrieval" + + setup_logger(f"{params.res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = glap_inference() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + paraspeechcaps_test_cuts = datamodule.paraspeechcaps_test_cuts() + paraspeechcaps_test_dl = datamodule.test_dataloaders(paraspeechcaps_test_cuts) + + test_sets = [ + "paraspeechcaps_test", + ] + test_dls = [ + paraspeechcaps_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + device=device, + test_dl=test_dl, + caption_type="short_captions", + return_details=True, + ) + metrics = result_dict["metrics"] + details = result_dict["details"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + with open(f"{params.res_dir}/details-decode", "w", encoding="utf-8") as f: + json.dump(details, f, ensure_ascii=False, indent=2) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/glap/evaluate_zero_shot_classification.py b/egs/emilia/CLAP/glap/evaluate_zero_shot_classification.py new file mode 100755 index 0000000000..c85804d245 --- /dev/null +++ b/egs/emilia/CLAP/glap/evaluate_zero_shot_classification.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from typing import Any, Dict + +import torch +from clap_datamodule import DataModule +from glap_model import glap_inference + +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + + return params + + +def map_iemocap_emotion_label_to_index(label: str) -> int: + label_map = { + "hap": 0, + "exc": 1, + "ang": 2, + "sad": 3, + "neu": 4, + } + return label_map[label] + + +def map_ravdess_emotion_label_to_index(label: str) -> int: + label_map = { + "angry": 0, + "calm": 1, + "disgust": 2, + "fearful": 3, + "happy": 4, + "sad": 5, + "surprised": 6, + "neutral": 7, + } + return label_map[label] + + +def map_ravdess_gender_label_to_index(label: str) -> int: + label_map = { + "male": 0, + "female": 1, + } + return label_map[label] + + +def map_cremad_emotion_label_to_index(label: str) -> int: + label_map = { + "H": 0, + "S": 1, + "A": 2, + "F": 3, + "D": 4, + "N": 5, + } + return label_map[label] + + +def map_cremad_age_label_to_index(label: str) -> int: + if label < 20: + index = 0 + elif label < 40: + index = 1 + elif label < 60: + index = 2 + else: + index = 3 + return index + + +def generate_iemocap_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a excited tone.", + "A speaker in a angry tone.", + "A speaker in a sad tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_emotion_prompts() -> str: + return [ + "A speaker in a angry tone.", + "A speaker in a calm tone.", + "A speaker in a disgust tone.", + "A speaker in a fear tone.", + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a surprised tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_gender_prompts() -> str: + return [ + "A male speaker.", + "A female speaker.", + ] + + +def generate_cremad_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a angry tone.", + "A speaker in a fear tone.", + "A speaker in a disgust tone.", + "A speaker in a neutral tone.", + ] + + +def generate_cremad_age_prompts() -> str: + return [ + "A child or young teenager speaker.", + "An adult speaker.", + "A middle-aged speaker.", + "An older or elder speaker.", + ] + + +def evaluate( + params: AttributeDict, + model: Any, + device: torch.device, + test_set: str, + test_dl: torch.utils.data.DataLoader, +) -> Dict[str, float]: + """Run the Zero-Shot Classification evaluation process.""" + metrics = {} + eval_info = { + "all_audio_features": [], + "all_gt_labels": [], + } + + if test_set == "iemocap_emotion": + prompts = generate_iemocap_emotion_prompts() + elif test_set == "ravdess_emotion": + prompts = generate_ravdess_emotion_prompts() + elif test_set == "ravdess_gender": + prompts = generate_ravdess_gender_prompts() + elif test_set == "cremad_emotion": + prompts = generate_cremad_emotion_prompts() + elif test_set == "cremad_age": + prompts = generate_cremad_age_prompts() + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + text_features = model.encode_text(prompts, device=device) + + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + audio_lens = batch["audio_lens"].to(device) + + if test_set == "iemocap_emotion": + gt_labels = [ + map_iemocap_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_emotion": + gt_labels = [ + map_ravdess_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_gender": + gt_labels = [ + map_ravdess_gender_label_to_index(c.supervisions[0].gender) + for c in batch["cuts"] + ] + elif test_set == "cremad_emotion": + gt_labels = [ + map_cremad_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "cremad_age": + gt_labels = [ + map_cremad_age_label_to_index(c.supervisions[0].age) + for c in batch["cuts"] + ] + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + audio_features = model.encode_audio(audio, audio_lens) + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_gt_labels"].extend(gt_labels) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + all_audio_features = torch.cat(eval_info["all_audio_features"]) + all_text_features = text_features.cpu() + all_gt_labels = torch.tensor(eval_info["all_gt_labels"], dtype=torch.int64) + metrics_single_dataset = compute_metrics( + audio_features=all_audio_features, + text_features=all_text_features, + gt_labels=all_gt_labels, + test_set=test_set, + ) + metrics.update(metrics_single_dataset) + + result_dict = {"metrics": metrics} + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + gt_labels: torch.Tensor, + test_set: str, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + + logits_per_audio = torch.matmul(audio_features, text_features.t()) + preds = logits_per_audio.argmax(dim=1) + + if test_set == "iemocap_emotion": + gt_labels = gt_labels.clamp(min=1) + preds = preds.clamp(min=1) + + wa = (preds == gt_labels).float().mean().item() + + recall_sum = 0.0 + num_classes = 0 + for cls_idx in torch.unique(gt_labels): + cls_idx = cls_idx.item() + cls_mask = gt_labels == cls_idx + recall = (preds[cls_mask] == cls_idx).float().mean().item() + recall_sum += recall + num_classes += 1 + logging.info(f"{test_set}: cls {cls_idx}, recall {recall}") + uar = recall_sum / num_classes if num_classes > 0 else 0.0 + + return {"wa": wa, "uar": uar} + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "zero-shot-classification" + + setup_logger(f"{params.res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = glap_inference() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + iemocap_test_cuts = datamodule.iemocap_cuts() + iemocap_test_dl = datamodule.test_dataloaders(iemocap_test_cuts) + + ravdess_test_cuts = datamodule.ravdess_cuts() + ravdess_test_dl = datamodule.test_dataloaders(ravdess_test_cuts) + + cremad_test_cuts = datamodule.cremad_cuts() + cremad_test_dl = datamodule.test_dataloaders(cremad_test_cuts) + + test_sets = [ + "iemocap_emotion", + "ravdess_emotion", + "cremad_emotion", + "ravdess_gender", + "cremad_age", + ] + test_dls = [ + iemocap_test_dl, + ravdess_test_dl, + cremad_test_dl, + ravdess_test_dl, + cremad_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + device=device, + test_set=test_set, + test_dl=test_dl, + ) + metrics = result_dict["metrics"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/laion_clap/clap_datamodule.py b/egs/emilia/CLAP/laion_clap/clap_datamodule.py new file mode 120000 index 0000000000..1ab77496dd --- /dev/null +++ b/egs/emilia/CLAP/laion_clap/clap_datamodule.py @@ -0,0 +1 @@ +../glap/clap_datamodule.py \ No newline at end of file diff --git a/egs/emilia/CLAP/laion_clap/evaluate.sh b/egs/emilia/CLAP/laion_clap/evaluate.sh new file mode 100755 index 0000000000..589977b253 --- /dev/null +++ b/egs/emilia/CLAP/laion_clap/evaluate.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +export PYTHONPATH=/root/icefall:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=$1 + +md=800 + +exp_dir=laion_clap/exp + +echo $exp_dir + +if false; then +python laion_clap/evaluate_retrieval.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +fi + +if true; then +python laion_clap/evaluate_zero_shot_classification.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +fi + +# python /root/busygpu/run.py & diff --git a/egs/emilia/CLAP/laion_clap/evaluate_retrieval.py b/egs/emilia/CLAP/laion_clap/evaluate_retrieval.py new file mode 100755 index 0000000000..e01ccbb35c --- /dev/null +++ b/egs/emilia/CLAP/laion_clap/evaluate_retrieval.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +import torch +from clap_datamodule import DataModule +from laion_clap import CLAP_Module + +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + + return params + + +def evaluate( + params: AttributeDict, + model: Any, + device: torch.device, + test_dl: torch.utils.data.DataLoader, + caption_type: str, + return_details: bool = False, +) -> Dict[str, float]: + """Run the Speech-Text Retrieval evaluation process.""" + metrics = {} + num_samples = 0 + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + eval_detail = { + "all_audio_paths": [], + "all_texts": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + # audio_lens = batch["audio_lens"].to(device) + + if caption_type == "short_captions": + captions = [c.supervisions[0].short_captions[0] for c in batch["cuts"]] + elif caption_type == "long_captions": + captions = [c.supervisions[0].long_captions[-1] for c in batch["cuts"]] + else: + raise ValueError + + audio_features = model.get_audio_embedding_from_data(audio, use_tensor=True) + text_features = model.get_text_embedding(captions, use_tensor=True) + + num_samples += audio_features.shape[0] + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + if return_details: + eval_detail["all_audio_paths"].extend( + [c.recording.sources[0].source for c in batch["cuts"]] + ) + eval_detail["all_texts"].extend(captions) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + all_audio_features = torch.cat(eval_info["all_audio_features"]) + all_text_features = torch.cat(eval_info["all_text_features"]) + metrics_single_dataset, details_single_dataset = compute_metrics( + audio_features=all_audio_features, + text_features=all_text_features, + ) + metrics.update(metrics_single_dataset) + + if return_details: + details = {} + for k, ranks in details_single_dataset.items(): + if k == "audio_to_text_ranks": + src_list = eval_detail["all_audio_paths"] + tgt_list = eval_detail["all_texts"] + elif k == "text_to_audio_ranks": + src_list = eval_detail["all_texts"] + tgt_list = eval_detail["all_audio_paths"] + else: + raise ValueError + + details[k] = { + src_list[i]: [ + f"GT# {tgt_list[j]}" if j == i else tgt_list[j] for j in ranking + ] + for i, ranking in enumerate(ranks) + } + + result_dict = {"metrics": metrics} + if return_details: + result_dict["details"] = details + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + logits_per_audio = audio_features @ text_features.t() + logits_per_text = logits_per_audio.t() + + metrics = {} + metrics["num_samples"] = N + + details = {} + + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + # preds = torch.where(ranking == ground_truth)[1] + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + details[f"{name}_ranks"] = ranking.detach().cpu().tolist() + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics, details + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "speech-text-retrieval" + + setup_logger(f"{params.res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = CLAP_Module(enable_fusion=False) + model.load_ckpt() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + paraspeechcaps_test_cuts = datamodule.paraspeechcaps_test_cuts() + paraspeechcaps_test_dl = datamodule.test_dataloaders(paraspeechcaps_test_cuts) + + test_sets = [ + "paraspeechcaps_test", + ] + test_dls = [ + paraspeechcaps_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + device=device, + test_dl=test_dl, + caption_type="long_captions", + return_details=True, + ) + metrics = result_dict["metrics"] + details = result_dict["details"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + with open(f"{params.res_dir}/details-decode", "w", encoding="utf-8") as f: + json.dump(details, f, ensure_ascii=False, indent=2) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/laion_clap/evaluate_zero_shot_classification.py b/egs/emilia/CLAP/laion_clap/evaluate_zero_shot_classification.py new file mode 100755 index 0000000000..2ce1eecffa --- /dev/null +++ b/egs/emilia/CLAP/laion_clap/evaluate_zero_shot_classification.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from typing import Any, Dict + +import torch +from clap_datamodule import DataModule +from laion_clap import CLAP_Module + +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + + return params + + +def map_iemocap_emotion_label_to_index(label: str) -> int: + label_map = { + "hap": 0, + "exc": 1, + "ang": 2, + "sad": 3, + "neu": 4, + } + return label_map[label] + + +def map_ravdess_emotion_label_to_index(label: str) -> int: + label_map = { + "angry": 0, + "calm": 1, + "disgust": 2, + "fearful": 3, + "happy": 4, + "sad": 5, + "surprised": 6, + "neutral": 7, + } + return label_map[label] + + +def map_ravdess_gender_label_to_index(label: str) -> int: + label_map = { + "male": 0, + "female": 1, + } + return label_map[label] + + +def map_cremad_emotion_label_to_index(label: str) -> int: + label_map = { + "H": 0, + "S": 1, + "A": 2, + "F": 3, + "D": 4, + "N": 5, + } + return label_map[label] + + +def map_cremad_age_label_to_index(label: str) -> int: + if label < 20: + index = 0 + elif label < 40: + index = 1 + elif label < 60: + index = 2 + else: + index = 3 + return index + + +def generate_iemocap_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a excited tone.", + "A speaker in a angry tone.", + "A speaker in a sad tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_emotion_prompts() -> str: + return [ + "A speaker in a angry tone.", + "A speaker in a calm tone.", + "A speaker in a disgust tone.", + "A speaker in a fear tone.", + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a surprised tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_gender_prompts() -> str: + return [ + "A male speaker.", + "A female speaker.", + ] + + +def generate_cremad_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a angry tone.", + "A speaker in a fear tone.", + "A speaker in a disgust tone.", + "A speaker in a neutral tone.", + ] + + +def generate_cremad_age_prompts() -> str: + return [ + "A child or young teenager speaker.", + "An adult speaker.", + "A middle-aged speaker.", + "An older or elder speaker.", + ] + + +def evaluate( + params: AttributeDict, + model: Any, + device: torch.device, + test_set: str, + test_dl: torch.utils.data.DataLoader, +) -> Dict[str, float]: + """Run the Zero-Shot Classification evaluation process.""" + metrics = {} + eval_info = { + "all_audio_features": [], + "all_gt_labels": [], + } + + if test_set == "iemocap_emotion": + prompts = generate_iemocap_emotion_prompts() + elif test_set == "ravdess_emotion": + prompts = generate_ravdess_emotion_prompts() + elif test_set == "ravdess_gender": + prompts = generate_ravdess_gender_prompts() + elif test_set == "cremad_emotion": + prompts = generate_cremad_emotion_prompts() + elif test_set == "cremad_age": + prompts = generate_cremad_age_prompts() + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + text_features = model.get_text_embedding(prompts, use_tensor=True) + + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + audio_lens = batch["audio_lens"].to(device) + + if test_set == "iemocap_emotion": + gt_labels = [ + map_iemocap_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_emotion": + gt_labels = [ + map_ravdess_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_gender": + gt_labels = [ + map_ravdess_gender_label_to_index(c.supervisions[0].gender) + for c in batch["cuts"] + ] + elif test_set == "cremad_emotion": + gt_labels = [ + map_cremad_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "cremad_age": + gt_labels = [ + map_cremad_age_label_to_index(c.supervisions[0].age) + for c in batch["cuts"] + ] + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + audio_features = model.get_audio_embedding_from_data(audio, use_tensor=True) + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_gt_labels"].extend(gt_labels) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + all_audio_features = torch.cat(eval_info["all_audio_features"]) + all_text_features = text_features.cpu() + all_gt_labels = torch.tensor(eval_info["all_gt_labels"], dtype=torch.int64) + metrics_single_dataset = compute_metrics( + audio_features=all_audio_features, + text_features=all_text_features, + gt_labels=all_gt_labels, + test_set=test_set, + ) + metrics.update(metrics_single_dataset) + + result_dict = {"metrics": metrics} + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + gt_labels: torch.Tensor, + test_set: str, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + + logits_per_audio = torch.matmul(audio_features, text_features.t()) + preds = logits_per_audio.argmax(dim=1) + + if test_set == "iemocap_emotion": + gt_labels = gt_labels.clamp(min=1) + preds = preds.clamp(min=1) + + wa = (preds == gt_labels).float().mean().item() + + recall_sum = 0.0 + num_classes = 0 + for cls_idx in torch.unique(gt_labels): + cls_idx = cls_idx.item() + cls_mask = gt_labels == cls_idx + recall = (preds[cls_mask] == cls_idx).float().mean().item() + recall_sum += recall + num_classes += 1 + logging.info(f"{test_set}: cls {cls_idx}, recall {recall}") + uar = recall_sum / num_classes if num_classes > 0 else 0.0 + + return {"wa": wa, "uar": uar} + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "zero-shot-classification" + + setup_logger(f"{params.res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = CLAP_Module(enable_fusion=False) + model.load_ckpt() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + iemocap_test_cuts = datamodule.iemocap_cuts() + iemocap_test_dl = datamodule.test_dataloaders(iemocap_test_cuts) + + ravdess_test_cuts = datamodule.ravdess_cuts() + ravdess_test_dl = datamodule.test_dataloaders(ravdess_test_cuts) + + cremad_test_cuts = datamodule.cremad_cuts() + cremad_test_dl = datamodule.test_dataloaders(cremad_test_cuts) + + test_sets = [ + "iemocap_emotion", + "ravdess_emotion", + "cremad_emotion", + "ravdess_gender", + "cremad_age", + ] + test_dls = [ + iemocap_test_dl, + ravdess_test_dl, + cremad_test_dl, + ravdess_test_dl, + cremad_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + device=device, + test_set=test_set, + test_dl=test_dl, + ) + metrics = result_dict["metrics"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/local/attach_long_captions.py b/egs/emilia/CLAP/local/attach_long_captions.py new file mode 100644 index 0000000000..3a08c7022d --- /dev/null +++ b/egs/emilia/CLAP/local/attach_long_captions.py @@ -0,0 +1,303 @@ +import os + +os.environ["VLLM_USE_V1"] = "0" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import argparse +import multiprocessing as mp +import time +from base64 import b64encode +from io import BytesIO +from multiprocessing import Process, Queue +from pathlib import Path +from string import Template + +import torch + +os.environ["OMP_NUM_THREADS"] = str( + min(8, os.cpu_count() // torch.cuda.device_count() + 2) +) +os.environ["MKL_NUM_THREADS"] = str( + min(8, os.cpu_count() // torch.cuda.device_count() + 2) +) + +import soundfile as sf +from datasets import load_dataset +from datasets.features import Audio +from lhotse import CutSet, load_manifest_lazy +from qwen_omni_utils import process_mm_info +from tqdm import tqdm +from transformers import Qwen3OmniMoeProcessor +from vllm import LLM, SamplingParams + +MODEL_PATH = "./download/Qwen3-Omni-30B-A3B-Captioner" +MAX_TOKENS = 512 +MAX_MODEL_LEN = 2048 +MAX_SAMPLES_IN_QUEUE = 100_000 +USER_PROMPT = Template( + """Your task is to generate a caption describing **only the characteristics of the speaker's voice**. + +Use the following tags in the caption: +$tag_block + +### CRITICAL RULES +1. **NEVER** describe the content of the speech. Do not quote any words or phrases. **NEVER** contain quotation marks (""). +2. **FOCUS ONLY ON THE HUMAN VOICE**. **NEVER** describe background, environment, audio quality. +3. **NEVER** mention the absence of characteristics (describe only what is present, not mention what is not present). +4. **NEVER** over-interpret or guess. +5. Failure to follow these rules will result in an invalid output. + +-- + +### Good Example +A young male with a clear, medium-high pitched voice and an American accent speaks in a casual, conversational style, much like a reviewer or vlogger. He begins at a fast, rushed pace with a highly energetic and emphatic intonation, using a high pitch to express strong emphasis. After a slight inhale, he continues to speak quickly and enthusiastically, maintaining a moderately loud volume and an expressive, fluctuating tone throughout the fluent delivery. + +--- + +### YOUR CAPTION:""" +) + + +def set_affinity_for_process(rank, total): + num_cpus = os.cpu_count() + cpus_per_proc = min(8, num_cpus // total) + start = rank * cpus_per_proc + end = min(num_cpus, start + cpus_per_proc) + os.sched_setaffinity(0, range(start, end)) + print(f"[PID {os.getpid()}] bound to CPUs {list(range(start, end))}") + + +def build_input(processor, messages): + text = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + audios, images, videos = process_mm_info(messages, use_audio_in_video=True) + + inputs = { + "prompt": text, + "multi_modal_data": {}, + "mm_processor_kwargs": { + "use_audio_in_video": False, + }, + } + + if images is not None: + inputs["multi_modal_data"]["image"] = images + if videos is not None: + inputs["multi_modal_data"]["video"] = videos + if audios is not None: + inputs["multi_modal_data"]["audio"] = audios + + return inputs + + +def producer(cuts_paths, queue, skip_uids): + set_affinity_for_process(rank=1, total=torch.cuda.device_count() + 2) + + processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH) + pbar = tqdm(desc="Building conversations") + for cuts_path in cuts_paths: + cuts = load_manifest_lazy(cuts_path) + for cut in cuts: + while queue.qsize() > MAX_SAMPLES_IN_QUEUE: + print("Producer sleeping for queue to drain...") + time.sleep(10) + pbar.update(1) + + if cut.id in skip_uids: + continue + + if cut.duration >= 30: + print("Skip audio duration larger than 30s") + continue + + cut = cut.resample(16000) + audio = cut.load_audio() + sr = cut.sampling_rate + + audio_buffer = BytesIO() + sf.write(audio_buffer, audio.T, sr, format="wav") + audio_bytes = audio_buffer.getvalue() + audio_b64 = "data:audio/wav;base64," + b64encode(audio_bytes).decode( + "utf-8" + ) + + tags = [] + accent = cut.supervisions[0].custom["accent"] + speaking_rate = cut.supervisions[0].custom["speaking_rate"] + situational_tags = cut.supervisions[0].custom["situational_tags"] + if accent: + tags.append(f"- **Accent**: {accent}") + if speaking_rate: + tags.append(f"- **Speaking Rate**: {speaking_rate}") + if situational_tags: + situational_tags = ", ".join(situational_tags) + tags.append(f"- **Emotion / Expressiveness**: {situational_tags}") + tag_block = "\n".join(tags) + user_prompt = USER_PROMPT.substitute(tag_block=tag_block) + + conversation = [ + { + "role": "user", + "content": [ + {"type": "audio", "audio": audio_b64}, + {"type": "text", "text": user_prompt}, + ], + } + ] + # the text are same across all samples + input_ = build_input(processor, conversation) + queue.put((Path(cuts_path).stem, cut, input_)) + pbar.close() + for _ in range(torch.cuda.device_count()): + queue.put(None) + + +def consumer( + producer_queue, consumer_queue, device, sampling_params, batch_size=64, seed=42 +): + set_affinity_for_process(rank=device + 2, total=torch.cuda.device_count() + 2) + + os.environ["CUDA_VISIBLE_DEVICES"] = str(device) + + llm = LLM( + model=MODEL_PATH, + trust_remote_code=True, + gpu_memory_utilization=0.97, + tensor_parallel_size=1, + limit_mm_per_prompt={"image": 0, "video": 0, "audio": 1}, + max_num_seqs=64, + max_num_batched_tokens=32768, + max_model_len=MAX_MODEL_LEN, + seed=seed, + ) + + cutsnames, cuts, inputs = [], [], [] + + def process_batch(cutsnames, cuts, inputs): + outputs = llm.generate(inputs, sampling_params=sampling_params) + for cutsname, cut, output in zip(cutsnames, cuts, outputs): + for result in output.outputs: + cut.supervisions[0].long_captions.append(result.text.strip()) + consumer_queue.put((cutsname, cut)) + + while True: + item = producer_queue.get() + if item is None: + break + + cutsname, cut, input_ = item + cutsnames.append(cutsname) + cuts.append(cut) + inputs.append(input_) + if len(inputs) < batch_size: + continue + process_batch(cutsnames, cuts, inputs) + cutsnames, cuts, inputs = [], [], [] + + if len(inputs) > 0: + process_batch(cutsnames, cuts, inputs) + consumer_queue.put(None) + + +if __name__ == "__main__": + mp.set_start_method("spawn", force=True) + + set_affinity_for_process(rank=0, total=torch.cuda.device_count() + 2) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuts_path", type=Path, help="Path to the input cuts list file." + ) + parser.add_argument("--tasks", type=Path, help="Path to the input task list file.") + parser.add_argument( + "--output_dir", + type=Path, + default=Path("./data/manifests"), + help="Path to the output directory", + ) + parser.add_argument( + "--seed", type=int, default=1234, help="Random seed for initialization." + ) + parser.add_argument( + "-b", "--batch_size", type=int, default=64, help="Batch size for processing." + ) + parser.add_argument( + "-n", + "--n_results_per_sample", + type=int, + default=1, + help="Number of results per sample.", + ) + args = parser.parse_args() + + if args.tasks is not None: + cuts_paths = [Path(line) for line in args.tasks.read_text().splitlines()] + else: + cuts_paths = [args.cuts_path] + + cutsname2jsonl_f = {} + skip_uids = set() + for cuts_path in cuts_paths: + output_path = ( + args.output_dir + / f"{cuts_path.name.replace(''.join(cuts_path.suffixes), '')}-attached.jsonl.gz" + ) + if output_path.exists(): + print(f"{output_path} already exists, about to load...") + cuts = load_manifest_lazy(output_path) + for cut in cuts: + skip_uids.add(cut.id) + cutsname2jsonl_f[Path(cuts_path).stem] = CutSet.open_writer( + output_path, overwrite=False + ) + + cuts_paths = [str(p) for p in cuts_paths] + producer_queue, consumer_queue = Queue(), Queue() + Process( + target=producer, + args=(cuts_paths, producer_queue, skip_uids), + daemon=True, + ).start() + + sampling_params = SamplingParams( + temperature=0.6, + top_p=0.95, + top_k=20, + max_tokens=MAX_TOKENS, + n=args.n_results_per_sample, + ) + for device in range(torch.cuda.device_count()): + Process( + target=consumer, + args=( + producer_queue, + consumer_queue, + device, + sampling_params, + args.batch_size, + args.seed, + ), + daemon=True, + ).start() + + remaining_consumers = torch.cuda.device_count() + pbar = tqdm(desc="inference") + while True: + record = consumer_queue.get() + if record is None: + remaining_consumers -= 1 + if remaining_consumers == 0: + break + continue + cutsname, cut = record + f = cutsname2jsonl_f[cutsname] + pbar.update(1) + f.write(cut) + pbar.close() + + for f in cutsname2jsonl_f.values(): + f.close() diff --git a/egs/emilia/CLAP/local/convert_paraspeechcaps_hf_to_jsonl.py b/egs/emilia/CLAP/local/convert_paraspeechcaps_hf_to_jsonl.py new file mode 100644 index 0000000000..5831307a24 --- /dev/null +++ b/egs/emilia/CLAP/local/convert_paraspeechcaps_hf_to_jsonl.py @@ -0,0 +1,76 @@ +import json +import os +from collections import defaultdict + +from datasets import load_dataset + +splits = [ + "holdout", + "test", + "dev", + "train_base", + # "train_scaled", +] + +os.makedirs("data/manifests", exist_ok=True) + +for split in splits: + print(f"Processing split: {split}") + + ds = load_dataset("ajd12342/paraspeechcaps", split=split) + + data2sample = defaultdict(list) + for sample in ds: + data2sample[sample["source"]].append(sample) + + for source, samples in data2sample.items(): + output_path = os.path.join( + "data/manifests", f"paraspeechcaps_{split}-{source}.jsonl" + ) + + if os.path.exists(output_path): + print(f"{output_path} exists, skip") + continue + + with open(output_path, "w", encoding="utf-8") as f: + for sample in samples: + if source == "voxceleb": + audio_path = sample["relative_audio_path"].replace( + "_voicefixer", "" + ) + elif source == "expresso": + audio_path = os.path.join("expresso", sample["relative_audio_path"]) + elif source == "ears": + audio_path = os.path.join("ears", sample["relative_audio_path"]) + elif source == "emilia": + audio_path = os.path.join("Emilia", sample["relative_audio_path"]) + else: + raise ValueError + + audio_path = os.path.join("download", audio_path) + text = sample["transcription"] + caption = sample["text_description"] + + intrinsic_tags = sample["intrinsic_tags"] + situational_tags = sample["situational_tags"] + speaker = sample["name"] + gender = sample["gender"] + accent = sample["accent"] + pitch = sample["pitch"] + speaking_rate = sample["speaking_rate"] + + obj = { + "audio_path": audio_path, + "text": text, + "caption": caption, + "intrinsic_tags": intrinsic_tags, + "situational_tags": situational_tags, + "speaker": speaker, + "gender": gender, + "accent": accent, + "pitch": pitch, + "speaking_rate": speaking_rate, + } + f.write(json.dumps(obj, ensure_ascii=False) + "\n") + + print(split, source, len(samples), "->", output_path) diff --git a/egs/emilia/CLAP/local/describe_caption_token_lengths.py b/egs/emilia/CLAP/local/describe_caption_token_lengths.py new file mode 100644 index 0000000000..11e31177e8 --- /dev/null +++ b/egs/emilia/CLAP/local/describe_caption_token_lengths.py @@ -0,0 +1,132 @@ +import os +import sys +from multiprocessing import Pool, cpu_count + +import numpy as np +from lhotse import CutSet +from tqdm import tqdm + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +MANIFEST = sys.argv[1] + +print("Loading CutSet...") +cuts = CutSet.from_file(MANIFEST) + +short_samples = [] +long_samples = [] + +for cut in tqdm(cuts, desc="Collecting captions"): + audio_src = cut.recording.sources[0].source + + for sup in cut.supervisions: + custom = sup.custom + for cap in custom["short_captions"]: + short_samples.append({"audio": audio_src, "caption": cap}) + for cap in custom["long_captions"]: + long_samples.append({"audio": audio_src, "caption": cap}) + +print(f"#short_captions = {len(short_samples)}") +print(f"#long_captions = {len(long_samples)}") + +short_texts = [s["caption"] for s in short_samples] +long_texts = [s["caption"] for s in long_samples] + +_tokenizer = None + + +def _init_tokenizer(): + global _tokenizer + from transformers import RobertaTokenizer + + _tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + +def _token_length(text: str) -> int: + global _tokenizer + enc = _tokenizer( + text, + padding=False, + truncation=False, + return_attention_mask=False, + ) + return len(enc["input_ids"]) + + +def compute_lengths_mp(texts, num_workers: int | None = None, desc: str = "Tokenizing"): + if num_workers is None: + num_workers = min(80, cpu_count() - 1) + + print(f"{desc}: using {num_workers} workers") + with Pool( + processes=num_workers, + initializer=_init_tokenizer, + ) as pool: + lengths = list( + tqdm( + pool.imap(_token_length, texts, chunksize=128), + total=len(texts), + desc=desc, + ) + ) + return lengths + + +BINS = [0, 16, 32, 48, 64, 80, 96, 128, 256, 512, float("inf")] + + +def bucket_of(length: int, bins=BINS) -> int: + for i in range(len(bins) - 1): + if bins[i] <= length < bins[i + 1]: + return i + raise RuntimeError(f"Length {length} did not fall into any bucket.") + + +def print_stats_and_extreme_bucket_samples(name, samples, lens): + arr = np.array(lens) + print(f"\n=== {name} ===") + print(f"样本数: {len(arr)}") + print(f"min: {arr.min()}") + print(f"max: {arr.max()}") + print(f"mean: {arr.mean():.2f}") + for p in [50, 75, 90, 95, 99, 99.9]: + print(f"p{p}: {np.percentile(arr, p):.2f}") + + hist, bin_edges = np.histogram(arr, bins=BINS) + print("区间分布(左闭右开,最后一档右闭,>最后边界的不会计入这里):") + for i, cnt in enumerate(hist): + print(f"[{bin_edges[i]:>3.0f}, {bin_edges[i+1]:>3.0f}): {cnt}") + + # 找到 min/max 对应的桶 + min_len = int(arr.min()) + max_len = int(arr.max()) + min_bucket = bucket_of(min_len, BINS) + max_bucket = bucket_of(max_len, BINS) + + def bucket_str(idx: int) -> str: + if idx < len(BINS) - 1: + return f"[{BINS[idx]}, {BINS[idx+1]})" + else: + return f"[{BINS[idx]}, +inf)" + + cnt = 0 + print(f"\n>>> {name} 最小桶 {min_bucket} 区间 {bucket_str(min_bucket)} 的样本:") + for length, sample in zip(lens, samples): + if cnt < 5 and bucket_of(length, BINS) == min_bucket: + print(f"len={length}\taudio={sample['audio']}\tcaption={sample['caption']}") + cnt += 1 + + cnt = 0 + print(f"\n>>> {name} 最大桶 {max_bucket} 区间 {bucket_str(max_bucket)} 的样本:") + for length, sample in zip(lens, samples): + if cnt < 5 and bucket_of(length, BINS) == max_bucket: + print(f"len={length}\taudio={sample['audio']}\tcaption={sample['caption']}") + cnt += 1 + + +if __name__ == "__main__": + short_lens = compute_lengths_mp(short_texts, desc="Tokenizing short_captions") + long_lens = compute_lengths_mp(long_texts, desc="Tokenizing long_captions") + + print_stats_and_extreme_bucket_samples("short_captions", short_samples, short_lens) + print_stats_and_extreme_bucket_samples("long_captions", long_samples, long_lens) diff --git a/egs/emilia/CLAP/local/filter_and_select_captions.py b/egs/emilia/CLAP/local/filter_and_select_captions.py new file mode 100644 index 0000000000..9f25629e15 --- /dev/null +++ b/egs/emilia/CLAP/local/filter_and_select_captions.py @@ -0,0 +1,131 @@ +import argparse +import logging +import os +import re +from collections import Counter +from pathlib import Path + +import lhotse +from lhotse import CutSet, load_manifest_lazy + +MULTI_SPEAKER_PATTERN = re.compile( + r"\b(speakers|first speaker|second speaker)\b", + re.IGNORECASE, +) + +NO_PATTERN = re.compile( + r"\bno\b|\bnot\b|\bneither\b|\bnor\b|\bfree from\b|\bwith no\b|\bwithout\b|\blacking\b|\brather than\b", + re.IGNORECASE, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--cuts_path", type=Path, help="Path to the input cuts list file." + ) + parser.add_argument("--tasks", type=Path, help="Path to the input task list file.") + parser.add_argument( + "--output_dir", + type=Path, + default=Path("./data/manifests"), + help="Path to the output directory", + ) + + return parser.parse_args() + + +def validate_short_captions(short_captions, cut_id, min_len=32): + assert len(short_captions) in ( + 1, + 2, + ), f"short_captions length must be 1 or 2, got {len(short_captions)}" + + for idx, caption in enumerate(short_captions): + if len(caption) < min_len: + logging.info( + f"Filtered cut (id={cut_id}): " + f"short caption[{idx}] too short (len={len(caption)})" + ) + return False + + return True + + +def validate_long_captions(long_captions, cut_id): + for idx, caption in enumerate(long_captions): + if MULTI_SPEAKER_PATTERN.search(caption): + logging.info(f"Filtered cut (id={cut_id}): multi speaker detected") + return False + + return True + + +def filter_long_captions(long_captions): + long_captions = [caption for caption in long_captions if "\n" not in caption] + long_captions = sorted(long_captions, key=lambda x: len(x)) + long_captions = [caption for caption in long_captions if 128 <= len(caption) <= 768] + long_captions = [ + caption for caption in long_captions if not NO_PATTERN.search(caption) + ] + return long_captions + + +def main(): + args = get_parser() + + num_long_captions = [] + + if args.tasks is not None: + cuts_paths = [Path(line) for line in args.tasks.read_text().splitlines()] + else: + cuts_paths = [args.cuts_path] + + for cuts_path in cuts_paths: + output_path = ( + args.output_dir + / f"{cuts_path.name.replace(''.join(cuts_path.suffixes), '')}-selected.jsonl.gz" + ) + if os.path.exists(output_path): + print(f"{output_path} exists, skip") + return + + cuts = load_manifest_lazy(cuts_path) + logging.info(f"Loading manifest: {cuts_path}") + + filtered_cuts = [] + for cut in cuts: + short_captions = cut.supervisions[0].short_captions + if not validate_short_captions(short_captions, cut.id): + continue + + long_captions = cut.supervisions[0].long_captions + if not validate_long_captions(long_captions, cut.id): + continue + long_captions = filter_long_captions(long_captions) + if not long_captions: + continue + + cut.supervisions[0].long_captions = long_captions + + filtered_cuts.append(cut) + num_long_captions.append(len(long_captions)) + + filtered_cuts = CutSet.from_cuts(filtered_cuts) + logging.info(f"Saving to {output_path}") + filtered_cuts.to_file(output_path) + + long_counter = Counter(num_long_captions) + print("Number of long captions distribution:") + for count in sorted(long_counter.keys()): + print(f"Length={count}, count={long_counter[count]}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/local/generate_cremad_manifests.py b/egs/emilia/CLAP/local/generate_cremad_manifests.py new file mode 100644 index 0000000000..00b455b5f9 --- /dev/null +++ b/egs/emilia/CLAP/local/generate_cremad_manifests.py @@ -0,0 +1,131 @@ +import argparse +import glob +import json +import logging +import os + +import torch +from lhotse import CutSet +from lhotse.audio import Recording +from lhotse.cut import MonoCut +from lhotse.supervision import SupervisionSegment + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +EMOTIONS = [ + "H", + "S", + "A", + "F", + "D", + "N", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--dataset-dir", + type=str, + help="Path to the cremad dataset", + default="./download/cremad", + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/manifests", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + dataset_dir = args.dataset_dir + manifest_dir = args.manifest_dir + os.makedirs(manifest_dir, exist_ok=True) + + speaker_id2age = {} + with open(f"{dataset_dir}/VideoDemographics.csv", "r") as f: + next(f) + for line in f: + line = line.strip() + parts = line.split(",") + speaker_id = parts[0] + age = int(parts[1]) + speaker_id2age[speaker_id] = age + + for split in [ + "test", + # "valid", + # "train", + ]: + dataset = {} + + label_paths = sorted(glob.glob(f"{dataset_dir}/{split}/*.json")) + for label_path in label_paths: + with open(label_path, "r") as f: + item = json.load(f) + emotion = item["label"] + + audio_path = label_path.replace(".json", ".wav") + assert os.path.isfile(audio_path) + audio_name = audio_path.split("/", 1)[-1].replace(".wav", "") + speaker_id = audio_name.rsplit("/", 1)[-1].split("_", 1)[0] + age = speaker_id2age[speaker_id] + + dataset[audio_name] = [audio_path, speaker_id, age, emotion] + + logging.info(f"A total of {len(dataset)} clips!") + + cuts = [] + for i, (cut_id, info) in enumerate(dataset.items()): + audio_path, speaker_id, age, emotion = info + recording = Recording.from_file(audio_path, cut_id) + cut = MonoCut( + id=cut_id, + start=0, + duration=recording.duration, + channel=0, + recording=recording, + ) + supervision = SupervisionSegment( + id=cut_id, + recording_id=cut.recording.id, + start=0, + channel=0, + duration=cut.duration, + text="", + speaker=speaker_id, + ) + supervision.age = age + supervision.emotion = emotion + + cut.supervisions = [supervision] + cut = cut.resample(16000) + + cuts.append(cut) + + if i % 100 == 0 and i: + logging.info(f"Processed {i} cuts until now.") + + cuts = CutSet.from_cuts(cuts) + + manifest_output_dir = manifest_dir + "/" + f"cremad_cuts_{split}.jsonl.gz" + + logging.info(f"Storing the manifest to {manifest_output_dir}") + cuts.to_jsonl(manifest_output_dir) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/local/generate_iemocap_manifests.py b/egs/emilia/CLAP/local/generate_iemocap_manifests.py new file mode 100644 index 0000000000..18a3364832 --- /dev/null +++ b/egs/emilia/CLAP/local/generate_iemocap_manifests.py @@ -0,0 +1,125 @@ +import argparse +import glob +import logging +import os + +import torch +from lhotse import CutSet +from lhotse.audio import Recording +from lhotse.cut import MonoCut +from lhotse.supervision import SupervisionSegment + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +EMOTIONS = ["ang", "hap", "neu", "exc", "sad"] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--dataset-dir", + type=str, + help="Path to the iemocap dataset", + default="./download/IEMOCAP", + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/manifests", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + dataset_dir = args.dataset_dir + manifest_dir = args.manifest_dir + os.makedirs(manifest_dir, exist_ok=True) + + for session_id in [1, 2, 3, 4, 5]: + wav_folder = f"{dataset_dir}/Session{session_id}/dialog/wav" + label_folder = f"{dataset_dir}/Session{session_id}/dialog/EmoEvaluation" + + label_files = sorted(glob.glob(f"{label_folder}/Ses*.txt")) + + dataset = {} + + for label in label_files: + with open(label, "r") as f: + data = f.readlines() + + for line in data: + # skip lines + if line[0] != "[": + continue + items = line.strip().split("\t") + timestamp = items[0].replace("[", "").replace("]", "").split() + timestamp = [float(timestamp[0]), float(timestamp[2])] + clip_name = items[1] + audio_name = clip_name.rsplit("_", 1)[0] + emotion = items[2] + audio_name = wav_folder + "/" + f"{audio_name}.wav" + + assert os.path.isfile(audio_name) + assert clip_name not in dataset + + dataset[clip_name] = [audio_name, timestamp, emotion] + + logging.info(f"A total of {len(dataset)} clips!") + + cuts = [] + for i, (cut_id, info) in enumerate(dataset.items()): + audio_file, timestamp, emotion = info + recording = Recording.from_file(audio_file, cut_id) + if emotion not in EMOTIONS: + continue + if emotion == "exc": + emotion = "hap" + assert recording.sampling_rate == 16000 + cut = MonoCut( + id=cut_id, + start=timestamp[0], + duration=timestamp[1] - timestamp[0], + channel=0, + recording=recording, + ) + supervision = SupervisionSegment( + id=cut_id, + recording_id=cut.recording.id, + start=0.0, + channel=0, + duration=cut.duration, + text="", + ) + supervision.emotion = emotion + + cut.supervisions = [supervision] + cuts.append(cut) + + if i % 100 == 0 and i: + logging.info(f"Processed {i} cuts until now.") + + logging.info(f"After filtering, a total of {len(cuts)} valid samples.") + cuts = CutSet.from_cuts(cuts) + + manifest_output_dir = ( + manifest_dir + "/" + f"iemocap_cuts_session{session_id}.jsonl.gz" + ) + + logging.info(f"Storing the manifest to {manifest_output_dir}") + cuts.to_jsonl(manifest_output_dir) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/local/generate_paraspeechcaps_manifests.py b/egs/emilia/CLAP/local/generate_paraspeechcaps_manifests.py new file mode 100644 index 0000000000..e0a3d66612 --- /dev/null +++ b/egs/emilia/CLAP/local/generate_paraspeechcaps_manifests.py @@ -0,0 +1,245 @@ +import argparse +import json +import logging +import os +import re +import tarfile + +from lhotse import CutSet +from lhotse.audio import Recording +from lhotse.cut import MonoCut +from lhotse.supervision import SupervisionSegment +from normalize_paraspeechcaps_short_captions import normalize + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--output-dir", + type=str, + default="data/manifests", + ) + + return parser.parse_args() + + +def process_psc_base(args, subset, source): + manifests_file = f"{args.output_dir}/paraspeechcaps_{subset}-{source}.jsonl" + output_path = ( + args.output_dir + "/" + f"paraspeechcaps_cuts_{subset}-{source}.jsonl.gz" + ) + + if os.path.exists(output_path): + print(f"{output_path} exists, skip") + return + + cuts = [] + num_cuts = 0 + + logging.info(f"Loading manifest: {manifests_file}") + with open(manifests_file) as reader: + for line in reader: + item = json.loads(line) + + audio_path = item["audio_path"] + + speaker = item["speaker"].strip() + gender = item["gender"].strip() + accent = item["accent"].strip() + pitch = item["pitch"].strip() + speaking_rate = item["speaking_rate"].strip() + intrinsic_tags = [i.strip() for i in item["intrinsic_tags"]] + situational_tags = ( + [i.strip() for i in item["situational_tags"]] + if item["situational_tags"] is not None + else [] + ) + + transcription = item["text"].strip() + short_captions = [ + normalize(re.sub(r"[\t\n\r]", " ", i).strip(), accent) + for i in item["caption"] + ] + + cut_id = ( + subset + + "-" + + source + + "-" + + audio_path.replace("download/", "") + .replace("/", "-") + .replace(".wav", "") + ) + + if not os.path.exists(audio_path): + logging.warning(f"No such file: {audio_path}") + continue + + recording = Recording.from_file(audio_path, cut_id) + cut = MonoCut( + id=cut_id, + start=0.0, + duration=recording.duration, + channel=0, + recording=recording, + ) + + supervision = SupervisionSegment( + id=recording.id, + recording_id=recording.id, + start=0.0, + channel=0, + duration=recording.duration, + text=transcription, + speaker=speaker, + ) + supervision.short_captions = short_captions + supervision.long_captions = [] + + supervision.gender = gender + supervision.accent = accent + supervision.pitch = pitch + supervision.speaking_rate = speaking_rate + supervision.intrinsic_tags = intrinsic_tags + supervision.situational_tags = situational_tags + + cut.supervisions = [supervision] + cut = cut.resample(16000) + cuts.append(cut) + + num_cuts += 1 + if num_cuts % 100 == 0 and num_cuts: + logging.info(f"Processed {num_cuts} cuts until now.") + + cut_set = CutSet.from_cuts(cuts) + + logging.info(f"Saving to {output_path}") + cut_set.to_file(output_path) + + +def process_psc_scaled(args, subset, source): + manifests_file = f"{args.output_dir}/paraspeechcaps_{subset}-{source}.jsonl" + output_path = ( + args.output_dir + "/" + f"paraspeechcaps_cuts_{subset}-{source}.jsonl.gz" + ) + + if os.path.exists(output_path): + print(f"{output_path} exists, skip") + return + + items = [] + logging.info(f"Loading manifest: {manifests_file}") + with open(manifests_file) as reader: + for line in reader: + items.append(json.loads(line)) + + def extract_key(audio_path: str) -> str: + return "".join(re.search(r"EN_B(\d+)_S\d+(\d)", audio_path).groups()) + + items_sorted = sorted(items, key=lambda x: int(extract_key(x["audio_path"]))) + + audio_output_dir = "./download/Emilia-audio" + os.makedirs(audio_output_dir, exist_ok=True) + + cuts = [] + num_cuts = 0 + current_tar_key = None + current_tar_handle = None + for item in items_sorted: + audio_path_in_tar = item["audio_path"].rsplit("/", 1)[-1] + transcription = item["text"].strip() + assert len(item["caption"]) == 1, item["caption"] + short_captions = [re.sub(r"[\t\n\r]", " ", item["caption"][0]).strip()] + + tar_key = extract_key(audio_path_in_tar) + while True: + if tar_key != current_tar_key: + if current_tar_handle: + current_tar_handle.close() + tar_path = f"./download/Emilia/EN/EN-B{tar_key}.tar" + logging.info(f"About to open tar: {tar_path}") + current_tar_handle = tarfile.open(tar_path, "r") + current_tar_key = tar_key + + audio_path = os.path.join(audio_output_dir, audio_path_in_tar) + try: + with open(audio_path, "wb") as f: + f.write(current_tar_handle.extractfile(audio_path_in_tar).read()) + break + except: + logging.warning( + f"KeyError: filename {audio_path_in_tar} not found in {tar_path}" + ) + tar_key = f"{tar_key[:-1] + str((int(tar_key[-1]) + 9) % 10)}" + continue + + cut_id = subset + "-" + source + "-" + audio_path_in_tar.replace(".mp3", "") + + recording = Recording.from_file(audio_path, cut_id) + cut = MonoCut( + id=cut_id, + start=0.0, + duration=recording.duration, + channel=0, + recording=recording, + ) + + supervision = SupervisionSegment( + id=recording.id, + recording_id=recording.id, + start=0.0, + channel=0, + duration=recording.duration, + text=transcription, + ) + supervision.short_captions = short_captions + supervision.long_captions = [] + + cut.supervisions = [supervision] + cut = cut.resample(16000) + cuts.append(cut) + + num_cuts += 1 + if num_cuts % 100 == 0 and num_cuts: + logging.info(f"Processed {num_cuts} cuts until now.") + + if current_tar_handle: + current_tar_handle.close() + + cut_set = CutSet.from_cuts(cuts) + + logging.info(f"Saving to {output_path}") + cut_set.to_file(output_path) + + +def main(): + args = get_parser() + os.makedirs(args.output_dir, exist_ok=True) + + split2subsets = { + "psc-base": ["test", "dev", "holdout", "train_base"], + # "psc-scaled": ["train_scaled"], + } + + split2sources = { + "psc-base": ["voxceleb", "expresso", "ears"], + "psc-scaled": ["emilia"], + } + + for split, subsets in split2subsets.items(): + for subset in subsets: + for source in split2sources[split]: + if split == "psc-base": + process_psc_base(args, subset, source) + elif split == "psc-scaled": + process_psc_scaled(args, subset, source) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/local/generate_paraspeechcaps_test_manifests.py b/egs/emilia/CLAP/local/generate_paraspeechcaps_test_manifests.py new file mode 100644 index 0000000000..ffc8cd9b2d --- /dev/null +++ b/egs/emilia/CLAP/local/generate_paraspeechcaps_test_manifests.py @@ -0,0 +1,43 @@ +from collections import Counter + +from lhotse import CutSet + + +def main(): + for split in ["voxceleb", "expresso", "ears"]: + test_cuts = CutSet.from_file( + f"data/manifests/paraspeechcaps_cuts_test-{split}.jsonl.gz" + ) + + test_sources = [] + for cut in test_cuts: + test_sources.append(cut.recording.sources[0].source) + + counter = Counter(test_sources) + duplicates = [k for k, v in counter.items() if v > 1] + + print(f"Found duplicated audio samples: {duplicates} from test cuts.") + + test_sources = set(test_sources) + + print(f"Collected {len(test_sources)} unique sources from test cuts.") + + holdout_cuts = CutSet.from_file( + f"data/manifests/paraspeechcaps_cuts_holdout-{split}.jsonl.gz" + ) + + filtered_cuts = CutSet.from_cuts( + cut + for cut in holdout_cuts + if cut.recording.sources[0].source in test_sources + ) + + print(f"Filtered cuts: {len(filtered_cuts)} remaining from holdout set.") + + filtered_cuts.to_file( + f"data/manifests/paraspeechcaps_cuts_test-{split}.jsonl.gz" + ) + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/local/generate_ravdess_manifests.py b/egs/emilia/CLAP/local/generate_ravdess_manifests.py new file mode 100644 index 0000000000..fad2479dcb --- /dev/null +++ b/egs/emilia/CLAP/local/generate_ravdess_manifests.py @@ -0,0 +1,124 @@ +import argparse +import glob +import json +import logging +import os + +import torch +from lhotse import CutSet +from lhotse.audio import Recording +from lhotse.cut import MonoCut +from lhotse.supervision import SupervisionSegment + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +EMOTIONS = [ + "angry", + "calm", + "disgust", + "fearful", + "happy", + "sad", + "surprised", + "neutral", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--dataset-dir", + type=str, + help="Path to the ravdess dataset", + default="./download/ravdess", + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/manifests", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + dataset_dir = args.dataset_dir + manifest_dir = args.manifest_dir + os.makedirs(manifest_dir, exist_ok=True) + + for fold_id in [0, 1, 2, 3]: + dataset = {} + + label_paths = sorted(glob.glob(f"{dataset_dir}/fold_{fold_id}/*.json")) + for label_path in label_paths: + with open(label_path, "r") as f: + item = json.load(f) + emotion = item["emotion"] + + audio_path = label_path.replace(".json", ".wav") + assert os.path.isfile(audio_path) + audio_name = audio_path.split("/", 2)[-1].replace(".wav", "") + + speaker_id = int(audio_name.rsplit("-", 1)[-1]) + if speaker_id % 2 == 0: + gender = "female" + else: + gender = "male" + + dataset[audio_name] = [audio_path, gender, emotion] + + logging.info(f"A total of {len(dataset)} clips!") + + cuts = [] + for i, (cut_id, info) in enumerate(dataset.items()): + audio_path, gender, emotion = info + recording = Recording.from_file(audio_path, cut_id) + cut = MonoCut( + id=cut_id, + start=0, + duration=recording.duration, + channel=0, + recording=recording, + ) + supervision = SupervisionSegment( + id=cut_id, + recording_id=cut.recording.id, + start=0, + channel=0, + duration=cut.duration, + text="", + gender=gender, + ) + supervision.emotion = emotion + + cut.supervisions = [supervision] + cut = cut.resample(16000) + + cuts.append(cut) + + if i % 100 == 0 and i: + logging.info(f"Processed {i} cuts until now.") + + cuts = CutSet.from_cuts(cuts) + + manifest_output_dir = ( + manifest_dir + "/" + f"ravdess_cuts_fold{fold_id}.jsonl.gz" + ) + + logging.info(f"Storing the manifest to {manifest_output_dir}") + cuts.to_jsonl(manifest_output_dir) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/local/normalize_paraspeechcaps_short_captions.py b/egs/emilia/CLAP/local/normalize_paraspeechcaps_short_captions.py new file mode 100644 index 0000000000..e8ac570d1f --- /dev/null +++ b/egs/emilia/CLAP/local/normalize_paraspeechcaps_short_captions.py @@ -0,0 +1,171 @@ +import re + +import regex + + +def remove_brackets(text: str) -> str: + # 删除 (Or) 以及之后的所有内容 + text = re.sub(r"[\(\[\{\<]or[\)\]\}\>].*", "", text, flags=re.I) + + # 去括号及内容 + pattern = re.compile(r"\([^()]*\)|\[[^[\]]*]|\{[^{}]*\}|<[^<>]*>") + while True: + new_text = pattern.sub("", text) + if new_text == text: + break + text = new_text + + # 清理残留的单个括号符号 + text = re.sub(r"[()\[\]{}<>]", "", text) + + # 删除 Note: | Or: | Description: 及之后的所有内容 + text = re.sub(r"\b(note|or|description):.*", "", text, flags=re.I) + + # 删除冒号之前所有内容 + text = re.sub(r"^.*?:\s*", "", text) + + return text + + +def map_phrases(text: str) -> str: + text = re.sub(r"\bwomen's\b", "woman's", text, flags=re.I) + text = re.sub(r"\bmen's\b", "man's", text, flags=re.I) + text = re.sub(r"\bwomen\b", "woman", text, flags=re.I) + text = re.sub(r"\bmen\b", "man", text, flags=re.I) + text = re.sub(r"\b([\w-]+)\s+in\s+origin\b", r"\1 accent", text) + text = re.sub(r"\borigin\b", "accent", text) + text = re.sub(r"\bcontinent\b", "accent", text) + text = re.sub(r"\baccents\b", "accent", text) + text = re.sub(r"aussie", "Australian", text, flags=re.I) + text = text.replace("environs", "environment") + text = re.sub(r"\s*,?\s*however\s*,?\s*", " ", text, flags=re.I) + + text = re.sub(r"\s+([,.;!?])", r"\1", text) # 标点前空格 + text = re.sub(r"\s{2,}", " ", text) # 连续空格 + text = re.sub(r"^[,.;!?]+\s*", "", text) # 标点开头 + text = text.replace(",,", ",") + text = text.replace(",.", ".") + text = text.replace(".,", ".") + text = text.replace("..", ".") + text = text.strip() + + # 修正每个句首大小写 + text = re.sub( + r"(^|[.!?]\s+)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text + ) + + return text + + +def process_accent(text: str, accent: str) -> str: + if len(accent) == 0: + return text + + def to_display_form(w: str) -> str: + w = w.lower() + parts = re.split(r"([/\-\s])", w) + return "".join(p.capitalize() if p.isalpha() else p for p in parts).strip() + + if "/" in accent: + accent = [accent] + accent.split("/") + else: + accent = [accent] + + is_missing = True + for w in accent: + + base = w.lower() + display = to_display_form(w) + + exact_pattern = re.compile(rf"\b{re.escape(base)}\b", re.I) + m = exact_pattern.search(text) + + if not m: + max_edit = min(2, max(0, len(w.replace(" ", "")) - 5)) + fuzzy_pattern = regex.compile( + rf"(?i)\b({regex.escape(base)}){{e<={max_edit}}}\b" + ) + m = fuzzy_pattern.search(text) + + if m: + matched_text = m.group() + + if " " not in base and " " in matched_text.strip(): + continue + + span = m.span() + prefix = " " if matched_text.startswith(" ") else "" + suffix = " " if matched_text.endswith(" ") else "" + text = text[: span[0]] + prefix + display + suffix + text[span[1] :] + + is_missing = False + + if is_missing: + text, count = re.subn(r"(?i)\b\w+\s+accent\b", f"{display} accent", text) + if count == 0: + text += f" {display} accent." + + return text + + +def normalize(text: str, accent: str) -> str: + text = remove_brackets(text) + text = map_phrases(text) + text = process_accent(text, accent) + return text + + +def _normalize(ori_text: str, accent: str) -> tuple[str, str]: + norm_text = remove_brackets(ori_text) + norm_text = map_phrases(norm_text) + norm_text = process_accent(norm_text, accent) + return ori_text, norm_text + + +if __name__ == "__main__": + import difflib + import sys + from multiprocessing import Pool + + RED = "\033[31m" + GREEN = "\033[32m" + RESET = "\033[0m" + + def color_diff_ori(ori_text: str, norm_text: str) -> str: + sm = difflib.SequenceMatcher(a=ori_text, b=norm_text, autojunk=False) + out = [] + + for tag, i1, i2, j1, j2 in sm.get_opcodes(): + # tag: 'equal', 'replace', 'delete', 'insert' + if tag == "equal": + out.append(ori_text[i1:i2]) + elif tag in ("replace", "delete"): + out.append(f"{RED}{ori_text[i1:i2]}{RESET}") + elif tag == "insert": + out.append(f"{GREEN}{norm_text[j1:j2]}{RESET}") + + return "".join(out) + + input_path = sys.argv[1] + success_path = sys.argv[2] + badcase_path = sys.argv[3] + + with open(input_path, "r", encoding="utf-8") as f: + lines = [line.rstrip("\n").rsplit(" ", 1) for line in f] + + with Pool(processes=64) as pool: + normalized = pool.starmap(_normalize, lines) + + with Pool(processes=64) as pool: + colored = pool.starmap(color_diff_ori, normalized) + + with open(success_path, "w", encoding="utf-8") as f_success, open( + badcase_path, "w", encoding="utf-8" + ) as f_bad: + + for (ori_text, norm_text), diff_line in zip(normalized, colored): + if ori_text != norm_text: + f_success.write(diff_line + "\n") + # f_success.write(norm_text + "\n") + else: + f_bad.write(ori_text + "\n") diff --git a/egs/emilia/CLAP/local/select_paraspeechcaps_test_manifests.py b/egs/emilia/CLAP/local/select_paraspeechcaps_test_manifests.py new file mode 100644 index 0000000000..783d4a755c --- /dev/null +++ b/egs/emilia/CLAP/local/select_paraspeechcaps_test_manifests.py @@ -0,0 +1,36 @@ +from lhotse import CutSet + + +def main(): + selected_cuts = [] + signatures = set() + + for split in ["voxceleb", "expresso", "ears"]: + holdout_cuts = CutSet.from_file( + f"data/manifests/paraspeechcaps_cuts_holdout-{split}.jsonl.gz" + ) + + for cut in holdout_cuts: + sup = cut.supervisions[0] + custom = sup.custom + + gender = sup.gender + accent = custom.get("accent") + pitch = custom.get("pitch") + speaking_rate = custom.get("speaking_rate") + + situational_tags = custom.get("situational_tags", []) + situational_tags = frozenset(situational_tags) + + signature = (gender, accent, pitch, speaking_rate, situational_tags) + + if signature not in signatures: + signatures.add(signature) + selected_cuts.append(cut) + + selected_cuts = CutSet.from_cuts(selected_cuts) + selected_cuts.to_file(f"data/manifests/paraspeechcaps_cuts_test.jsonl.gz") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/paraclap/clap_datamodule.py b/egs/emilia/CLAP/paraclap/clap_datamodule.py new file mode 120000 index 0000000000..1ab77496dd --- /dev/null +++ b/egs/emilia/CLAP/paraclap/clap_datamodule.py @@ -0,0 +1 @@ +../glap/clap_datamodule.py \ No newline at end of file diff --git a/egs/emilia/CLAP/paraclap/evaluate.sh b/egs/emilia/CLAP/paraclap/evaluate.sh new file mode 100755 index 0000000000..622da17ebc --- /dev/null +++ b/egs/emilia/CLAP/paraclap/evaluate.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +export PYTHONPATH=/root/icefall:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=$1 + +md=800 + +exp_dir=paraclap/exp + +echo $exp_dir + +if false; then +python paraclap/evaluate_retrieval.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +fi + +if true; then +python paraclap/evaluate_zero_shot_classification.py \ + --manifest-dir data/manifests \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +fi + +# python /root/busygpu/run.py & diff --git a/egs/emilia/CLAP/paraclap/evaluate_retrieval.py b/egs/emilia/CLAP/paraclap/evaluate_retrieval.py new file mode 100755 index 0000000000..d3653d05b9 --- /dev/null +++ b/egs/emilia/CLAP/paraclap/evaluate_retrieval.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +import torch +from clap_datamodule import DataModule +from model import CLAP +from transformers import AutoTokenizer + +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + + return params + + +def evaluate( + params: AttributeDict, + model: Any, + tokenizer: AutoTokenizer, + device: torch.device, + test_dl: torch.utils.data.DataLoader, + caption_type: str, + return_details: bool = False, +) -> Dict[str, float]: + """Run the Speech-Text Retrieval evaluation process.""" + metrics = {} + num_samples = 0 + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + eval_detail = { + "all_audio_paths": [], + "all_texts": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + + if caption_type == "short_captions": + captions = [c.supervisions[0].short_captions[0] for c in batch["cuts"]] + elif caption_type == "long_captions": + captions = [c.supervisions[0].long_captions[-1] for c in batch["cuts"]] + else: + raise ValueError + + text = tokenizer( + captions, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + text_features, audio_features, logit_scale = model( + audio=audio, + text=text, + ) + + num_samples += audio_features.shape[0] + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + if return_details: + eval_detail["all_audio_paths"].extend( + [c.recording.sources[0].source for c in batch["cuts"]] + ) + eval_detail["all_texts"].extend(captions) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + all_audio_features = torch.cat(eval_info["all_audio_features"]) + all_text_features = torch.cat(eval_info["all_text_features"]) + metrics_single_dataset, details_single_dataset = compute_metrics( + audio_features=all_audio_features, + text_features=all_text_features, + logit_scale=logit_scale.cpu(), + ) + metrics.update(metrics_single_dataset) + + if return_details: + details = {} + for k, ranks in details_single_dataset.items(): + if k == "audio_to_text_ranks": + src_list = eval_detail["all_audio_paths"] + tgt_list = eval_detail["all_texts"] + elif k == "text_to_audio_ranks": + src_list = eval_detail["all_texts"] + tgt_list = eval_detail["all_audio_paths"] + else: + raise ValueError + + details[k] = { + src_list[i]: [ + f"GT# {tgt_list[j]}" if j == i else tgt_list[j] for j in ranking + ] + for i, ranking in enumerate(ranks) + } + + result_dict = {"metrics": metrics} + if return_details: + result_dict["details"] = details + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + audio_features = audio_features / torch.norm(audio_features, dim=-1, keepdim=True) + text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True) + + logits_per_text = logit_scale * (text_features @ audio_features.t()) + logits_per_audio = logits_per_text.t() + + metrics = {} + metrics["num_samples"] = N + + details = {} + + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + # preds = torch.where(ranking == ground_truth)[1] + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + details[f"{name}_ranks"] = ranking.detach().cpu().tolist() + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics, details + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "speech-text-retrieval" + + setup_logger(f"{params.res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + ckpt = torch.hub.load_state_dict_from_url( + url="https://huggingface.co/KeiKinn/paraclap/resolve/main/best.pth.tar?download=true", + map_location="cpu", + check_hash=True, + ) + text_model = "bert-base-uncased" + tokenizer = AutoTokenizer.from_pretrained(text_model) + audio_model = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" + model = CLAP( + speech_name=audio_model, + text_name=text_model, + embedding_dim=768, + ) + model.load_state_dict(ckpt) + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + paraspeechcaps_test_cuts = datamodule.paraspeechcaps_test_cuts() + paraspeechcaps_test_dl = datamodule.test_dataloaders(paraspeechcaps_test_cuts) + + test_sets = [ + "paraspeechcaps_test", + ] + test_dls = [ + paraspeechcaps_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + tokenizer=tokenizer, + device=device, + test_dl=test_dl, + caption_type="long_captions", + return_details=True, + ) + metrics = result_dict["metrics"] + details = result_dict["details"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + with open(f"{params.res_dir}/details-decode", "w", encoding="utf-8") as f: + json.dump(details, f, ensure_ascii=False, indent=2) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/paraclap/evaluate_zero_shot_classification.py b/egs/emilia/CLAP/paraclap/evaluate_zero_shot_classification.py new file mode 100755 index 0000000000..5ffccd72da --- /dev/null +++ b/egs/emilia/CLAP/paraclap/evaluate_zero_shot_classification.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from typing import Any, Dict + +import torch +from clap_datamodule import DataModule +from model import CLAP +from transformers import AutoTokenizer + +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + + return params + + +def map_iemocap_emotion_label_to_index(label: str) -> int: + label_map = { + "hap": 0, + "ang": 1, + "sad": 2, + "neu": 3, + } + return label_map[label] + + +def map_ravdess_emotion_label_to_index(label: str) -> int: + label_map = { + "angry": 0, + "calm": 1, + "disgust": 2, + "fearful": 3, + "happy": 4, + "sad": 5, + "surprised": 6, + "neutral": 7, + } + return label_map[label] + + +def map_ravdess_gender_label_to_index(label: str) -> int: + label_map = { + "male": 0, + "female": 1, + } + return label_map[label] + + +def map_cremad_emotion_label_to_index(label: str) -> int: + label_map = { + "H": 0, + "S": 1, + "A": 2, + "F": 3, + "D": 4, + "N": 5, + } + return label_map[label] + + +def map_cremad_age_label_to_index(label: str) -> int: + if label < 20: + index = 0 + elif label < 40: + index = 1 + elif label < 60: + index = 2 + else: + index = 3 + return index + + +def generate_iemocap_emotion_prompts() -> str: + return [ + "this person is feeling happy.", + "this person is feeling angry.", + "this person is feeling sad.", + "this person is feeling neutral.", + ] + + +def generate_ravdess_emotion_prompts() -> str: + return [ + "angry", + "calm", + "disgust", + "fear", + "happy", + "sad", + "surprised", + "neutral", + ] + + +def generate_ravdess_gender_prompts() -> str: + return [ + "male", + "female", + ] + + +def generate_cremad_emotion_prompts() -> str: + return [ + "happy", + "sad", + "angry", + "fear", + "disgust", + "neutral", + ] + + +def generate_cremad_age_prompts() -> str: + return [ + "teenager", + "young adult", + "middle-aged", + "older", + ] + + +def evaluate( + params: AttributeDict, + model: Any, + tokenizer: AutoTokenizer, + device: torch.device, + test_set: str, + test_dl: torch.utils.data.DataLoader, +) -> Dict[str, float]: + """Run the Zero-Shot Classification evaluation process.""" + metrics = {} + eval_info = { + "all_audio_features": [], + "all_gt_labels": [], + } + + if test_set == "iemocap_emotion": + prompts = generate_iemocap_emotion_prompts() + elif test_set == "ravdess_emotion": + prompts = generate_ravdess_emotion_prompts() + elif test_set == "ravdess_gender": + prompts = generate_ravdess_gender_prompts() + elif test_set == "cremad_emotion": + prompts = generate_cremad_emotion_prompts() + elif test_set == "cremad_age": + prompts = generate_cremad_age_prompts() + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + text = tokenizer( + prompts, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + text_features = model.forward_text_branch(text=text) + + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + audio = batch["audio"].to(device) + + if test_set == "iemocap_emotion": + gt_labels = [ + map_iemocap_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_emotion": + gt_labels = [ + map_ravdess_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "ravdess_gender": + gt_labels = [ + map_ravdess_gender_label_to_index(c.supervisions[0].gender) + for c in batch["cuts"] + ] + elif test_set == "cremad_emotion": + gt_labels = [ + map_cremad_emotion_label_to_index(c.supervisions[0].emotion) + for c in batch["cuts"] + ] + elif test_set == "cremad_age": + gt_labels = [ + map_cremad_age_label_to_index(c.supervisions[0].age) + for c in batch["cuts"] + ] + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + audio_features = model.forward_audio_branch(audio=audio) + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_gt_labels"].extend(gt_labels) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + all_audio_features = torch.cat(eval_info["all_audio_features"]) + all_text_features = text_features.cpu() + all_gt_labels = torch.tensor(eval_info["all_gt_labels"], dtype=torch.int64) + metrics_single_dataset = compute_metrics( + audio_features=all_audio_features, + text_features=all_text_features, + gt_labels=all_gt_labels, + test_set=test_set, + ) + metrics.update(metrics_single_dataset) + + result_dict = {"metrics": metrics} + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + gt_labels: torch.Tensor, + test_set: str, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + + audio_features = audio_features / torch.norm(audio_features, dim=-1, keepdim=True) + text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True) + + logits_per_text = torch.matmul(text_features, audio_features.t()) + logits_per_audio = logits_per_text.t() + preds = logits_per_audio.argmax(dim=1) + + wa = (preds == gt_labels).float().mean().item() + + recall_sum = 0.0 + num_classes = 0 + for cls_idx in torch.unique(gt_labels): + cls_idx = cls_idx.item() + cls_mask = gt_labels == cls_idx + recall = (preds[cls_mask] == cls_idx).float().mean().item() + recall_sum += recall + num_classes += 1 + logging.info(f"{test_set}: cls {cls_idx}, recall {recall}") + uar = recall_sum / num_classes if num_classes > 0 else 0.0 + + return {"wa": wa, "uar": uar} + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "zero-shot-classification" + + setup_logger(f"{params.res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + ckpt = torch.hub.load_state_dict_from_url( + url="https://huggingface.co/KeiKinn/paraclap/resolve/main/best.pth.tar?download=true", + map_location="cpu", + check_hash=True, + ) + text_model = "bert-base-uncased" + tokenizer = AutoTokenizer.from_pretrained(text_model) + audio_model = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" + model = CLAP( + speech_name=audio_model, + text_name=text_model, + embedding_dim=768, + ) + model.load_state_dict(ckpt) + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + iemocap_test_cuts = datamodule.iemocap_cuts() + iemocap_test_dl = datamodule.test_dataloaders(iemocap_test_cuts) + + ravdess_test_cuts = datamodule.ravdess_cuts() + ravdess_test_dl = datamodule.test_dataloaders(ravdess_test_cuts) + + cremad_test_cuts = datamodule.cremad_cuts() + cremad_test_dl = datamodule.test_dataloaders(cremad_test_cuts) + + test_sets = [ + "iemocap_emotion", + "ravdess_emotion", + "cremad_emotion", + "ravdess_gender", + "cremad_age", + ] + test_dls = [ + iemocap_test_dl, + ravdess_test_dl, + cremad_test_dl, + ravdess_test_dl, + cremad_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + tokenizer=tokenizer, + device=device, + test_set=test_set, + test_dl=test_dl, + ) + metrics = result_dict["metrics"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/paraclap/model.py b/egs/emilia/CLAP/paraclap/model.py new file mode 100644 index 0000000000..280152d519 --- /dev/null +++ b/egs/emilia/CLAP/paraclap/model.py @@ -0,0 +1,80 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoModel, Wav2Vec2Model + + +class Projection(torch.nn.Module): + def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(d_in, d_out, bias=False) + self.linear2 = torch.nn.Linear(d_out, d_out, bias=False) + self.layer_norm = torch.nn.LayerNorm(d_out) + self.drop = torch.nn.Dropout(p) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + embed1 = self.linear1(x) + embed2 = self.drop(self.linear2(F.gelu(embed1))) + embeds = self.layer_norm(embed1 + embed2) + return embeds + + +class SpeechEncoder(torch.nn.Module): + def __init__(self, model_name): + super().__init__() + self.model_name = model_name + self.base = Wav2Vec2Model.from_pretrained(self.model_name) + self.hidden_size = self.base.config.hidden_size + + def forward(self, x): + x = self.base(x)["last_hidden_state"] + x = x.mean(1) + return x + + +class TextEncoder(torch.nn.Module): + def __init__(self, model_name: str) -> None: + super().__init__() + self.base = AutoModel.from_pretrained(model_name) + + def forward(self, x): + out = self.base(**x)[0] + out = out[:, 0, :].detach() # get CLS token output + return out + + +class CLAP(torch.nn.Module): + def __init__(self, speech_name: str, text_name: str, embedding_dim: int = 1024): + super().__init__() + + self.audio_branch = SpeechEncoder(model_name=speech_name) + + self.text_branch = TextEncoder(model_name=text_name) + self.audio_projection = Projection(self.audio_branch.hidden_size, embedding_dim) + self.text_projection = Projection( + self.text_branch.base.config.hidden_size, embedding_dim + ) + + self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def forward(self, audio, text): + speech_emb = self.audio_branch(audio) + text_emb = self.text_branch(text) + + speech_emb = self.audio_projection(speech_emb) + text_emb = self.text_projection(text_emb) + + return text_emb, speech_emb, self.logit_scale.exp() + + def forward_audio_branch(self, audio): + speech_emb = self.audio_branch(audio) + speech_emb = self.audio_projection(speech_emb) + + return speech_emb + + def forward_text_branch(self, text): + text_emb = self.text_branch(text) + text_emb = self.text_projection(text_emb) + + return text_emb diff --git a/egs/emilia/CLAP/spear/at_datamodule.py b/egs/emilia/CLAP/spear/at_datamodule.py new file mode 100644 index 0000000000..86d5f661c8 --- /dev/null +++ b/egs/emilia/CLAP/spear/at_datamodule.py @@ -0,0 +1,1293 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from dataset_at import MultiTaskDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--features-mask-size", + type=int, + default=27, + help="The maximum mask bins along the frequency axis in specaug" + ) + + group.add_argument( + "--frames-mask-size", + type=int, + default=100, + help="The maximum mask length along the time axis in specaug" + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--enable-mixup", + type=str2bool, + default=True, + help="When enabled, select random cuts from balanced set and mix" + "Note the label with also be mixed", + ) + + group.add_argument( + "--mixup-prob", + type=float, + default=0.5, + help="The probability of doing mixup to a cut" + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # ASR related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="L", + ) + + group.add_argument( + "--repeat-wenetspeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-aishell", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + # KD related + group.add_argument( + "--mvq-KD", + type=str2bool, + default=False, + help="If load the codebook indexes instead of ground truth of audio events" + ) + + group.add_argument( + "--at-KD", + type=str2bool, + default=False, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + world_size = 1 + rank = 0 + + transforms = [] + assert not self.args.enable_musan + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz").drop_features() + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=self.args.features_mask_size, + num_feature_masks=2, + frames_mask_size=self.args.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, + ) + ) + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}, " + f"frames_mask_size: {self.args.frames_mask_size}, " + f"features_mask_size: {self.args.features_mask_size}" + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = MultiTaskDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + assert self.args.on_the_fly_feats + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + if self.args.enable_mixup: + mixup_cuts = load_manifest("data/fbank_as_ced_mAP50/audioset_cuts_balanced.jsonl.gz").drop_features() + else: + mixup_cuts = None + + train = MultiTaskDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + mixup_prob=self.args.mixup_prob, + mixup_cuts=mixup_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True, + ) + else: + assert len(cuts_train) == 1, f"The training cuts contain {len(cuts_train)} cutsets" + cuts_train = list(cuts_train.values())[0] + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + train_sampler = WeightedSimpleCutSampler( + cuts_train, + weights, + num_samples=self.args.at_num_samples, + max_duration=self.args.max_duration, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/dev", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + if self.args.use_shar: + medium_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/libriheavy/medium", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + if self.args.libriheavy_subset == "medium": + return medium_cuts + else: + assert self.args.libriheavy_subset == "large" + large_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/libriheavy/large", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = [medium_cuts, large_cuts] + return CutSet.mux( + *cuts, + weights=[1, 9], + stop_early=False, + ) + + else: + return load_manifest_lazy( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.libriheavy_subset}.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.wenetspeech_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def aishell_train_cuts(self) -> CutSet: + logging.info("About to get aishell training cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_train.jsonl.gz") + + @lru_cache() + def aishell_dev_cuts(self) -> CutSet: + logging.info("About to get aishell dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") + + @lru_cache() + def aishell_test_cuts(self) -> CutSet: + logging.info("About to get aishell test cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") + + @lru_cache() + def mls_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + if self.args.use_shar: + num_splits = 8 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/MLS/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ).resample(16000) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","common_voice_20200622","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone","speech_wav"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + # alimeeting_cuts, ali_dur, ali_num_cuts = self.alimeeting_cuts() + # all_cuts.append(alimeeting_cuts) + # cuts_duration.append(ali_dur) + # cuts_len.append(ali_num_cuts) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def alimeeting_cuts(self): + # alimeeting: 140 hrs, 186364 cuts + def reduce_supervisions(c): + supervisions = c.supervisions + supervisions = [supervisions[0]] + c.supervisions = supervisions + return c + logging.info("About to get the alimeeting cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/alimeeting/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "alimeeting-far_cuts_train.jsonl.gz" + ) + cuts = cuts.map(reduce_supervisions) + + return cuts.drop_features(), 140, 186364 + + @cached_property + def dataset_duration_stats(self): + stats_file = f"{self.args.shar_dir}/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"{self.args.shar_dir}/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts.drop_features() + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/eval", + shuffle_shards=False, + ) + return cuts + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/attention_decoder.py b/egs/emilia/CLAP/spear/attention_decoder.py new file mode 100644 index 0000000000..bff536f90b --- /dev/null +++ b/egs/emilia/CLAP/spear/attention_decoder.py @@ -0,0 +1,583 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import List, Optional + +import k2 +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from scaling import penalize_abs_values_gt + +from icefall.utils import add_eos, add_sos, make_pad_mask + + +class AttentionDecoderModel(nn.Module): + """ + Args: + vocab_size (int): Number of classes. + decoder_dim: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + num_heads (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + sos_id: int = 1, + eos_id: int = 1, + dropout: float = 0.1, + ignore_id: int = -1, + label_smoothing: float = 0.1, + ): + super().__init__() + self.eos_id = eos_id + self.sos_id = sos_id + self.ignore_id = ignore_id + + # For the segment of the warmup period, we let the Embedding + # layer learn something. Then we start to warm up the other encoders. + self.decoder = TransformerDecoder( + vocab_size=vocab_size, + d_model=decoder_dim, + num_decoder_layers=num_decoder_layers, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + + # Used to calculate attention-decoder loss + self.loss_fun = LabelSmoothingLoss( + ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum" + ) + + def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor): + """Prepare ys_in_pad and ys_out_pad.""" + ys_in = add_sos(ys, sos_id=self.sos_id) + # [B, S+1], start with SOS + ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id) + ys_in_lens = ys_lens + 1 + + ys_out = add_eos(ys, eos_id=self.eos_id) + # [B, S+1], end with EOS + ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id) + + return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64) + + def calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys: k2.RaggedTensor, + ys_lens: torch.Tensor, + ) -> torch.Tensor: + """Calculate attention-decoder loss. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: The attention-decoder loss. + """ + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + loss = self.loss_fun(x=decoder_out, target=ys_out_pad) + return loss + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + token_ids: List[List[int]], + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from attention-decoder. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: A tensor of shape (batch, num_tokens). + """ + ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device) + row_splits = ys.shape.row_splits(1) + ys_lens = row_splits[1:] - row_splits[:-1] + + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + batch_size, _, num_classes = decoder_out.size() + nll = nn.functional.cross_entropy( + decoder_out.view(-1, num_classes), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction="none", + ) + nll = nll.view(batch_size, -1) + return nll + + +class TransformerDecoder(nn.Module): + """Transfomer decoder module. + + Args: + vocab_size: output dim + d_model: decoder dimension + num_decoder_layers: number of decoder layers + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + vocab_size: int, + d_model: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + super().__init__() + self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model) + + # Absolute positional encoding + self.pos = PositionalEncoding(d_model, dropout_rate=0.1) + + self.num_layers = num_decoder_layers + self.layers = nn.ModuleList( + [ + DecoderLayer( + d_model=d_model, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + for _ in range(num_decoder_layers) + ] + ) + + self.output_layer = nn.Linear(d_model, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + memory: Optional[torch.Tensor] = None, + memory_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch, tgt_len). + x_lens: A tensor of shape (batch,) containing the number of tokens in `x` + before padding. + memory: + Memory sequence of shape (batch, src_len, memory_dim). + memory_lens: + A tensor of shape (batch,) containing the number of frames in + `memory` before padding. + + Returns: + Decoded token logits before softmax (batch, tgt_len, vocab_size) + """ + x = self.embed(x) # (batch, tgt_len, embed_dim) + x = self.pos(x) # (batch, tgt_len, embed_dim) + + x = x.permute(1, 0, 2) # (tgt_len, batch, embed_dim) + + # construct attn_mask for self-attn modules + padding_mask = make_pad_mask(x_lens) # (batch, tgt_len) + causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) + attn_mask = torch.logical_or( + padding_mask.unsqueeze(1), # (batch, 1, seq_len) + torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len) + ) # (batch, seq_len, seq_len) + + if memory is not None: + memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim) + # construct memory_attn_mask for cross-attn modules + memory_padding_mask = make_pad_mask(memory_lens) # (batch, src_len) + memory_attn_mask = memory_padding_mask.unsqueeze(1) # (batch, 1, src_len) + else: + memory_attn_mask = None + + for i, mod in enumerate(self.layers): + x = mod( + x, + attn_mask=attn_mask, + memory=memory, + memory_attn_mask=memory_attn_mask, + ) + + x = x.permute(1, 0, 2) # (batch, tgt_len, vocab_size) + x = self.output_layer(x) + + return x + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + d_model: equal to decoder_dim, total dimension of the decoder + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + d_model: int = 512, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + """Construct an DecoderLayer object.""" + super(DecoderLayer, self).__init__() + + self.norm_self_attn = nn.LayerNorm(d_model) + self.self_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, dropout=0.0 + ) + + self.norm_src_attn = nn.LayerNorm(d_model) + self.src_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0 + ) + + self.norm_ff = nn.LayerNorm(d_model) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, feedforward_dim), + Swish(), + nn.Dropout(dropout), + nn.Linear(feedforward_dim, d_model), + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input sequence of shape (seq_len, batch, embed_dim). + attn_mask: A binary mask for self-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + memory: Memory sequence of shape (seq_len, batch, memory_dim). + memory_attn_mask: A binary mask for cross-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + """ + # self-attn module + qkv = self.norm_self_attn(x) + self_attn_out = self.self_attn( + query=qkv, key=qkv, value=qkv, attn_mask=attn_mask + ) + x = x + self.dropout(self_attn_out) + + # cross-attn module + q = self.norm_src_attn(x) + src_attn_out = self.src_attn( + query=q, key=memory, value=memory, attn_mask=memory_attn_mask + ) + x = x + self.dropout(src_attn_out) + + # feed-forward module + x = x + self.dropout(self.feed_forward(self.norm_ff(x))) + + return x + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, but must be a multiple of num_heads. + num_heads: number of parallel attention heads. + memory_dim: dimension of memory embedding, optional. + dropout: a Dropout layer on attn_output_weights. + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + memory_dim: Optional[int] = None, + dropout: float = 0.0, + ): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.head_dim = attention_dim // num_heads + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True) + self.linear_k = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + self.linear_v = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + + self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute dot product attention. + + Args: + query: Query tensor of shape (tgt_len, batch, embed_dim). + key: Key tensor of shape (src_len, batch, embed_dim or memory_dim). + value: Value tensor of shape (src_len, batch, embed_dim or memory_dim). + key_padding_mask: A binary mask indicating which elements are padding. + Its shape is (batch, src_len). + attn_mask: A binary mask indicating which elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + + Returns: + Output tensor of shape (tgt_len, batch, embed_dim). + """ + num_heads = self.num_heads + head_dim = self.head_dim + + tgt_len, batch, _ = query.shape + src_len = key.shape[0] + + q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim) + k = self.linear_k(key) # (src_len, batch, num_heads * head_dim) + v = self.linear_v(value) # (src_len, batch, num_heads * head_dim) + + q = q.reshape(tgt_len, batch, num_heads, head_dim) + q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim) + k = k.reshape(src_len, batch, num_heads, head_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len) + v = v.reshape(src_len, batch, num_heads, head_dim) + v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1) + + # Note: could remove the scaling operation when using ScaledAdam + # (batch, head, tgt_len, src_len) + attn_weights = torch.matmul(q, k) / math.sqrt(head_dim) + + # From zipformer.py: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + + if attn_mask is not None: + assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == ( + batch, + tgt_len, + src_len, + ), attn_mask.shape + attn_weights = attn_weights.masked_fill( + attn_mask.unsqueeze(1), float("-inf") + ) + + attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + # (batch * head, tgt_len, head_dim) + attn_output = torch.bmm(attn_weights, v) + assert attn_output.shape == ( + batch * num_heads, + tgt_len, + head_dim, + ), attn_output.shape + + attn_output = attn_output.transpose(0, 1).contiguous() + attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) + + # (batch, tgt_len, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class PositionalEncoding(nn.Module): + """Positional encoding. + Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def subsequent_mask(size, device="cpu", dtype=torch.bool): + """Create mask for subsequent steps (size, size). + + :param int size: size of mask + :param str device: "cpu" or "cuda" or torch.Tensor.device + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=dtype) + return torch.tril(ret, out=ret) + + +def _test_attention_decoder_model(): + m = AttentionDecoderModel( + vocab_size=500, + decoder_dim=512, + num_decoder_layers=6, + attention_dim=512, + num_heads=8, + feedforward_dim=2048, + memory_dim=384, + dropout=0.1, + sos_id=1, + eos_id=1, + ignore_id=-1, + ) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + m.eval() + encoder_out = torch.randn(2, 50, 384) + encoder_out_lens = torch.full((2,), 50) + token_ids = [[1, 2, 3, 4], [2, 3, 10]] + + nll = m.nll(encoder_out, encoder_out_lens, token_ids) + print(nll) + + +if __name__ == "__main__": + _test_attention_decoder_model() diff --git a/egs/emilia/CLAP/spear/augmentations.py b/egs/emilia/CLAP/spear/augmentations.py new file mode 100644 index 0000000000..76ecb59a4d --- /dev/null +++ b/egs/emilia/CLAP/spear/augmentations.py @@ -0,0 +1,228 @@ +import random +import time + +from lhotse.cut import CutSet, MonoCut, Cut +from lhotse.cut.set import mix + +def _mix_with_offset_deprecated_( + reference_cut: Cut, + mixed_in_cut: Cut, + snr: float = 10.0, + drop_mixed_in_supervision: bool = True +): + if drop_mixed_in_supervision: + mixed_in_cut = mixed_in_cut.drop_supervisions() + ref_duration = reference_cut.duration + mixed_in_duration = mixed_in_cut.duration + + mix_duration = random.uniform(0.1, ref_duration / 2) # 0.1 for safety + + # randomly truncate the mixed_in_cut to mix_duration if longer + if mixed_in_duration > mix_duration: + diff = max(0.0, mixed_in_duration - mix_duration - 0.05) + truncate_start = random.uniform(0, diff) + mixed_in_cut = mixed_in_cut.truncate(offset=truncate_start, duration=mix_duration) + + actual_mix_duration = min(mixed_in_cut.duration, mix_duration) + offset = random.uniform(0, ref_duration - actual_mix_duration - 0.05) # a tolerance of 0.05 for safety + mixed_cut = mix( + reference_cut=reference_cut, + mixed_in_cut=mixed_in_cut, + offset=offset, + snr=snr, + preserve_id="left", + ) + + return mixed_cut + +def mix_with_offset( + reference_cut: Cut, + mixed_in_cut: Cut, + snr: float = 10.0, + drop_mixed_in_supervision: bool = True, + *, + # 仅对“语音重叠”模式——如果是噪声注入,建议另行分支处理 + min_overlap_ratio: float = 0.20, # 下限 + max_overlap_ratio: float = 0.50, # 上限 + epsilon: float = 0.01, # tolerance +): + if drop_mixed_in_supervision and hasattr(mixed_in_cut, "drop_supervisions"): + mixed_in_cut = mixed_in_cut.drop_supervisions() + + ref_duration = float(reference_cut.duration) + if ref_duration <= (0.1 + epsilon): + return reference_cut # 极短段保护 + + # 计算严格 < 50% 的上界 + max_allowed = max(0.0, max_overlap_ratio * ref_duration - epsilon) + min_allowed = max(0.1, min_overlap_ratio * ref_duration) # 0.1s 安全下限与原逻辑一致 + if max_allowed < min_allowed: # 容错:极短段或参数不当 + min_allowed = max(0.05, min_allowed * 0.5) + max_allowed = max(min_allowed + epsilon, max_allowed) + + mix_duration = random.uniform(min_allowed, max_allowed) + + # 截断被混入段以满足目标 mix_duration + mixed_in_duration = float(mixed_in_cut.duration) + if mixed_in_duration > mix_duration: + # 在可行范围内随机起点后截断 + slack = max(0.0, mixed_in_duration - mix_duration - epsilon) + truncate_start = random.uniform(0.0, slack) if slack > 0 else 0.0 + mixed_in_cut = mixed_in_cut.truncate(offset=truncate_start, duration=mix_duration) + + actual_mix_duration = min(float(mixed_in_cut.duration), mix_duration) + + # 将被混入段完全放入参考段内部,保证真实 overlap = actual_mix_duration + hi = max(0.0, ref_duration - actual_mix_duration - epsilon) + offset = random.uniform(0.0, hi) if hi > 0 else 0.0 + + mixed_cut = mix( + reference_cut=reference_cut, + mixed_in_cut=mixed_in_cut, + offset=offset, + snr=snr, + preserve_id="left", + ) + return mixed_cut + + +class BatchMixing: + def __init__( + self, + min_snr: float = -5, + max_snr: float = 5, + min_noise_snr: float = -5, + p: float = 0.2, + p_noise: float = 0.1, + noise_cuts: CutSet = None, + drop_mixed_in_supervision: bool = True, + seed: int = 42, + stateful: bool = True, + ): + """perform in-batch mixing with the cuts from the same batch + + Args: + min_snr (float): minimum mix SNR for in-batch speech mixing + max_snr (float): maximum mix SNR + min_noise_snr (float): minimum mix SNR for noise mixing + p_noise (float, optional): The probability of perform noise mixing instead of in-batch. + p (float, optional): The probability of perform mixing to a cut. Defaults to 0.5. + noise_cuts (CutSet, optional): An optional noise cut. If provided, sample from the noise cuts instead of from the batch itself + drop_mixed_in_supervision (bool, optional): Remove the supervisions in the mixed_in_cut. Defaults to True. + """ + self.min_snr = min_snr + self.max_snr = max_snr + self.p = p + self.min_noise_snr = min_noise_snr + self.p_noise = p_noise + if p_noise > 0: + assert noise_cuts is not None, "If p_noise > 0, noise_cuts must be provided" + self.noise_cuts = noise_cuts + self.drop_mixed_in_supervision = drop_mixed_in_supervision + + self.seed = seed + self.stateful = stateful + self.num_times_iterated = 0 + + def __str__(self): + return f"BatchMixing: p={self.p}, snr=({self.min_snr}, {self.max_snr}), p_n={self.p_noise}, min_noise_snr={self.min_noise_snr}, drop_supervision={self.drop_mixed_in_supervision}" + + def __call__(self, reference_cuts: CutSet) -> CutSet: + from lhotse.dataset.dataloading import resolve_seed + + if isinstance(self.seed, random.Random): + rng = self.seed + else: + rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated + int(time.time() * 1000) % 100000) + + if self.stateful: + self.num_times_iterated += 1 + + if self.noise_cuts.is_lazy: + # If the noise input is lazy, we'll shuffle it approximately. + # We set the shuffling buffer size to 2000 because that's the size of MUSAN, + # so even if the user forgets to convert MUSAN to an eager manifest, they will + # get roughly the same quality of noise randomness. + # Note: we can't just call .to_eager() as the noise CutSet can technically be + # very large, or even hold data in-memory in case of webdataset/Lhotse Shar sources. + def noise_gen(): + yield from self.noise_cuts.repeat().shuffle(rng=rng, buffer_size=2000) + else: + # Eager nose cuts are just fully reshuffled in a different order on each noise "epoch". + def noise_gen(): + while True: + yield from self.noise_cuts.shuffle(rng=rng) + + noise_cuts = iter(noise_gen()) + results = [] + for cut in reference_cuts: + # perform augmentation + if rng.random() < self.p: + if self.p_noise > 0 and rng.random() < self.p_noise: + snr = rng.uniform(self.min_noise_snr, 20) # the max snr for noise mixing is 20dB + mixed_in_cut = next(noise_cuts) + mixed_cut = mix_with_offset( + cut, + mixed_in_cut, + snr=snr, + max_overlap_ratio=0.8, # noise 可以覆盖更多 + ) + # mixed_in_cut = self.noise_cuts.sample(n_cuts=1) # this should be rather quick + else: # same batch mixing + snr = rng.uniform(self.min_snr, self.max_snr) + mixed_in_cut = reference_cuts.sample(n_cuts=1) # this should be rather quick + while mixed_in_cut.id == cut.id: + mixed_in_cut = reference_cuts.sample(n_cuts=1) + mixed_cut = mix_with_offset( + cut, + mixed_in_cut, + snr=snr, + min_overlap_ratio=0.2, + max_overlap_ratio=0.5, + ) + results.append(mixed_cut) + else: + results.append(cut) + return CutSet.from_cuts(results) + +def _test_mix(): + from lhotse import load_manifest_lazy + from lhotse import load_manifest + manifest = "data/fbank/librispeech_cuts_dev-other.jsonl.gz" + noise_cuts = "data/musan/noise_non_speech_musan_audioset.jsonl.gz" + + cuts = load_manifest_lazy(manifest).subset(first=200).drop_features() + noise_cuts = load_manifest(noise_cuts).drop_features() + + from lhotse.cut import MixedCut + transform = BatchMixing( + min_snr=-5, + max_snr=5, + p=0.2, + min_noise_snr=5, + p_noise=0.5, + noise_cuts=noise_cuts, + drop_mixed_in_supervision=True + ) + # import pdb; pdb.set_trace() + start = time.time() + for i in range(1): + mixed_cuts = transform(cuts) + mix_durations = [] + for j, c in enumerate(mixed_cuts): + if isinstance(c, MixedCut): + mix_durations.append(c.tracks[1].cut.duration) + end = time.time() + print(f"Elasped: {end - start} seconds") + # print(sum(mix_durations)/len(mix_durations)) + + # print(mixed_cuts) + + # MixedCut( + # id='2067-143536-0050-22591_repeat0', + # tracks=[ + # MixTrack(cut=MonoCut(id='2067-143536-0050-22591_repeat0', start=0, duration=15.905, channel=0, supervisions=[SupervisionSegment(id='2067-143536-0050', recording_id='2067-143536-0050', start=0.0, duration=15.905, channel=0, text='AND MEN AND WOMEN MUTES WATCHING WITH HARD CURIOUS EYES THEN SEATED IN HER BARBARIC CHAIR ABOVE THEM ALL WITH MYSELF AT HER FEET WAS THE VEILED WHITE WOMAN WHOSE LOVELINESS AND AWESOME POWER SEEMED TO VISIBLY SHINE ABOUT HER LIKE A HALO', language='English', speaker='2067', gender=None, custom=None, alignment=None)], features=Features(type='kaldi-fbank', num_frames=1591, num_features=128, frame_shift=0.01, sampling_rate=16000, start=0, duration=15.905, storage_type='lilcom_chunky', storage_path='data/fbank/librispeech_feats_train-other-500/feats-1.lca', storage_key='272808444,75670,73700,73580,13443', recording_id='None', channels=0), recording=Recording(id='2067-143536-0050', sources=[AudioSource(type='file', channels=[0], source='/mnt/workspace/xiaoyu/workspace/icefall_prompt_multi_task/egs/librispeech/ASR/download/LibriSpeech/train-other-500/2067/143536/2067-143536-0050.flac')], sampling_rate=16000, num_samples=254480, duration=15.905, channel_ids=[0], transforms=None), custom={'codebook_indexes': TemporalArray(array=Array(storage_type='numpy_hdf5', storage_path='data_hdf5/vq_hubert_large_layer_21_normalize_1_cb_16/librispeech_cuts_train-all-shuf/librispeech_cuts_train-all-shuf-1.h5', storage_key='2067-143536-0050-22591', shape=[795, 16]), temporal_dim=0, frame_shift=0.02, start=0), 'shard_origin': PosixPath('data-shar/data-shar-hubert-large-layer-21-normalize-cb16-hdf5/librispeech/train-all-shuf/cuts.000083.jsonl.gz'), 'shar_epoch': 0, 'task_id': 1, 'dataloading_info': {'rank': 1, 'world_size': 8, 'worker_id': 1}}), type='MonoCut', offset=0.0, snr=None), + # MixTrack(cut=MonoCut(id='4ed48ac9-a0df-e8e9-de68-e8540e51ae78', start=10.9971875, duration=0.58, channel=0, supervisions=[], features=Features(type='kaldi-fbank', num_frames=1588, num_features=128, frame_shift=0.01, sampling_rate=16000, start=0, duration=15.875, storage_type='lilcom_chunky', storage_path='data/fbank/librispeech_feats_train-other-500/feats-8.lca', storage_key='96050666,74833,75029,74339,13685', recording_id='None', channels=0), recording=Recording(id='8346-244446-0072', sources=[AudioSource(type='file', channels=[0], source='/mnt/workspace/xiaoyu/workspace/icefall_prompt_multi_task/egs/librispeech/ASR/download/LibriSpeech/train-other-500/8346/244446/8346-244446-0072.flac')], sampling_rate=16000, num_samples=254000, duration=15.875, channel_ids=[0], transforms=None), custom={'codebook_indexes': TemporalArray(array=Array(storage_type='numpy_hdf5', storage_path='data_hdf5/vq_hubert_large_layer_21_normalize_1_cb_16/librispeech_cuts_train-all-shuf/librispeech_cuts_train-all-shuf-3.h5', storage_key='8346-244446-0072-7733', shape=[793, 16]), temporal_dim=0, frame_shift=0.02, start=0), 'shard_origin': PosixPath('data-shar/data-shar-hubert-large-layer-21-normalize-cb16-hdf5/librispeech/train-all-shuf/cuts.000128.jsonl.gz'), 'shar_epoch': 0, 'task_id': 1, 'dataloading_info': {'rank': 1, 'world_size': 8, 'worker_id': 1}}), type='MonoCut', offset=7.163035072927992, snr=2.8853809253275147)], transforms=None) + +if __name__=="__main__": + _test_mix() \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/beam_search.py b/egs/emilia/CLAP/spear/beam_search.py new file mode 120000 index 0000000000..e24eca39f2 --- /dev/null +++ b/egs/emilia/CLAP/spear/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/collect_zipformer_embeddings.py b/egs/emilia/CLAP/spear/collect_zipformer_embeddings.py new file mode 100644 index 0000000000..780a08d173 --- /dev/null +++ b/egs/emilia/CLAP/spear/collect_zipformer_embeddings.py @@ -0,0 +1,361 @@ + +import argparse +import os +import logging +from typing import Union, List, Dict +from pathlib import Path + +from train_multi_KD3_shar import add_model_arguments, get_encoder_embed, get_encoder_model +from zipformer2 import Zipformer2 + +import torch +import torch.multiprocessing as mp +from torch.utils.data import DataLoader + +from lhotse import load_manifest, CutSet +from lhotse.cut import MonoCut +from lhotse import Fbank, FbankConfig +from lhotse.dataset import DynamicBucketingSampler +from lhotse.dataset.input_strategies import BatchIO, OnTheFlyFeatures, PrecomputedFeatures +from lhotse.features.io import NumpyHdf5Writer +from lhotse.workarounds import Hdf5MemoryIssueFix + +from icefall.utils import AttributeDict, setup_logger, make_pad_mask + +class FbankDataset(torch.utils.data.Dataset): + def __init__( + self, + return_cuts: bool = True, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.input_strategy = input_strategy + + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + batch = { + "feature": inputs, + } + batch.update(supervision_intervals) + + if self.return_cuts: + batch["cuts"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + return batch + +class ZipformerModel(torch.nn.Module): + def __init__( + self, encoder_embed: torch.nn.Module, encoder: Zipformer2 + ): + super().__init__() + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.encoder_dim = encoder.encoder_dim + + def _get_full_dim_output_impl(self, outputs: List[torch.Tensor], max_depth): + output_dim = max(self.encoder_dim[:max_depth]) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[max_depth - 1] + + for i in range(max_depth - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def _get_full_dim_output(self, outputs: List[torch.Tensor], max_depth: int): + outputs = outputs[:max_depth] + return self._get_full_dim_output_impl(outputs, max_depth=max_depth) + + def get_embeddings(self, batch, layer_idx: int = -1): + device = next(self.parameters()).device + x = batch["feature"].to(device) + x_lens = batch["num_frames"].to(device) + + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens, layer_results = self.encoder( + x, x_lens, src_key_padding_mask, return_middle_out=True + ) + + if layer_idx == -1: + feature = encoder_out.permute(1, 0, 2) + else: + # the intermediate layers' feature are 50 Hz + feature = self._get_full_dim_output(layer_results, layer_idx) + # feature = layer_results[layer_idx-1] # index starts from 1 + feature = feature.permute(1, 0, 2) + encoder_out_lens = x_lens + + return feature, encoder_out_lens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=1, + ) + + parser.add_argument( + "--input-manifest", + type=str, + required=True, + ) + + parser.add_argument( + "--manifest-name", + type=str, + required=True, + help="name of the manifest, e.g embeddings-dev-clean, embeddings-train-clean-100" + ) + + parser.add_argument( + "--embedding-dir", + type=str, + default="data/embeddings" + ) + + parser.add_argument( + "--embedding-layer", + type=int, + default=-1, + help="Which layer's representation should be extracted, index start from 1, i.e the 10-th layer requires" + "--embedding-layer 10" + ) + + parser.add_argument( + "--max-duration", + type=int, + default=500, + ) + + parser.add_argument( + "--target-manifest-file", + type=str, + required=True, + help="Where to store the manifest augmented with whisper features" + ) + + # zipformer related args + parser.add_argument( + "--model-ckpt", + type=str, + required=True, + ) + + parser.add_argument( + "--zipformer-version", + type=str, + default="300m", + ) + + parser.add_argument( + "--frame-shift", + type=float, + default=0.02, + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=128, + ) + add_model_arguments(parser) + + return parser + +@torch.no_grad() +def extract_embeddings( + rank: int, + manifest: str, + params: AttributeDict, +): + setup_logger(f"data/embeddings/log/log-zipformer-embeddings") + if params.num_jobs > 1: + manifest = manifest[rank] + output_manifest = params.embedding_dir / f"zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}-{rank}.jsonl.gz" + embedding_path = params.embedding_dir / f'zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}-{rank}' + else: + output_manifest = params.embedding_dir / f"zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}.jsonl.gz" + embedding_path = params.embedding_dir / f'zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}' + + device = torch.device("cuda", rank) + + # currently only use the encoder of zipformer + logging.info(params) + model = ZipformerModel( + encoder_embed=get_encoder_embed(params), + encoder=get_encoder_model(params), + ) + state_dict = torch.load(params.model_ckpt)["model"] + load_info = model.load_state_dict(state_dict, strict=False) + logging.info(load_info) + + model.to(device) + model.eval() + logging.info(f"Number of zipformer model params: {sum(p.numel() for p in model.parameters())}") + logging.info(f"Successfully loaded zipformer model.") + + dataset = FbankDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=True + ) + + sampler = DynamicBucketingSampler( + manifest, + max_duration=params.max_duration, + shuffle=False, + num_buckets=20, + buffer_size=20 * 2000, + shuffle_buffer_size=20 * 5000, + drop_last=False, + ) + + dl = DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + new_cuts = [] + num_cuts = 0 + + with NumpyHdf5Writer(embedding_path) as writer: + logging.info(f"Writing zipformer embeddings to {embedding_path}") + for i, batch in enumerate(dl): + cuts = batch["cuts"] + + with torch.cuda.amp.autocast(enabled=True): + embeddings, embedding_lens = model.get_embeddings( + batch=batch, + layer_idx=params.embedding_layer # which layer's embedding to be stored + ) + embeddings = embeddings.detach().to("cpu").numpy() + + for idx, cut in enumerate(cuts): + new_cut = MonoCut( + id=cut.id, + start=cut.start, + duration=cut.duration, + channel=cut.channel, + ) + new_cut.embedding = writer.store_array( + key=cut.id, + value=embeddings[idx][: embedding_lens[idx]], + temporal_dim=0, + frame_shift=params.frame_shift, + start=cut.start, + ) + new_cuts.append(new_cut) + num_cuts += 1 + if num_cuts and i % 100 == 0: + logging.info(f"Cuts processed until now: {num_cuts}") + + logging.info(f"Finished extracting zipformer embeddings, processed a total of {num_cuts} cuts.") + + CutSet.from_cuts(new_cuts).to_jsonl(output_manifest) + logging.info(f"Saved manifest to {output_manifest}") + +def join_manifests( + input_cuts: CutSet, + embedding_manifest: str, + output_dir: str, +): + # Combine the teacher embedding manifest with the original manifest for ASR + embedding_cuts = load_manifest(embedding_manifest) + + assert len(embedding_cuts) == len(input_cuts) + assert set(input_cuts.ids) == set(embedding_cuts.ids) + + embedding_cuts = embedding_cuts.sort_like(input_cuts) + for cut_idx, (ori_cut, embed_cut) in enumerate(zip(input_cuts, embedding_cuts)): + assert ori_cut.id == embed_cut.id + ori_cut.embedding = embed_cut.embedding + + input_cuts.to_jsonl(output_dir) + print(f"Saved the joined manifest to {output_dir}") + +def remove_short_and_long_utt(c): + if c.duration < 1.0 or c.duration > 29.9: + return False + return True + +def remove_sp(c): + if "sp1.1" in c.id or "sp0.9" in c.id: + return False + return True + + +if __name__=="__main__": + parser = get_parser() + args = parser.parse_args() + params = AttributeDict() + params.update(vars(args)) + params.embedding_dir = Path(params.embedding_dir) + + nj = params.num_jobs + cuts = load_manifest(params.input_manifest) + cuts = cuts.filter(remove_short_and_long_utt) # remove audio longer than 30s + cuts = cuts.filter(remove_sp) # remove the speed perturbed audio + print(f"Finished loading manifest") + + embedding_manifest = params.embedding_dir / f"zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}.jsonl.gz" + + if not embedding_manifest.exists(): + if nj == 1: + extract_embeddings( + rank=0, + manifest=cuts, + params=params, + ) + else: + splitted_cuts = cuts.split(num_splits=nj) + print(f"Finished splitting manifest") + mp.spawn(extract_embeddings, args=(splitted_cuts, params), nprocs=nj, join=True) + manifests = f"{str(params.embedding_dir)}/zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}-*.jsonl.gz" + os.system(f"lhotse combine {manifests} {embedding_manifest}") + else: + print(f"Skip embedding extraction: the manifest is already generated.") + + output_manifest = params.target_manifest_file + if not os.path.exists(output_manifest): + join_manifests( + input_cuts=cuts, + embedding_manifest=embedding_manifest, + output_dir=output_manifest, + ) + + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/dataset.py b/egs/emilia/CLAP/spear/dataset.py new file mode 100644 index 0000000000..f0bd84fbe1 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset.py @@ -0,0 +1,311 @@ +import math +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + mvq_KD: bool = False, + at_KD: bool = False, + sv_KD: bool = False + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.mvq_KD = mvq_KD + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + # validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + if self.mvq_KD: + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + pad_value=-100, + ) + else: + mvq_tokens = None + mvq_token_lens = None + + if self.at_KD: + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [ + c.supervisions[0].audio_event if hasattr(c.supervisions[0], "audio_event") + else "0" for c in cuts_pre_mixed + ] # the label indices are in CED format + at_targets, _ = str2multihot(audio_events) # (N, num_events) + + # TODO: SV targets + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.codebook_indexes.array.storage_key != "dummy_whisper_codebook_indexes_1510" + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.beats_embedding.storage_key != "dummy_beats_embedding" + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: np.array = None, + temporal_array: bool = True, + pad_value=None, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * 50) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + + +if __name__=="__main__": + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("debug.jsonl.gz") + cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + cuts = load_manifest("debug.jsonl.gz") + + gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + import pdb; pdb.set_trace() + print(gt_mvq_tokens) + + gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + print(beats_embed) + diff --git a/egs/emilia/CLAP/spear/dataset2.py b/egs/emilia/CLAP/spear/dataset2.py new file mode 100644 index 0000000000..7baccb3864 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset2.py @@ -0,0 +1,209 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + at_KD: bool = False, + sv_KD: bool = False + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.at_KD = at_KD + self.sv_KD = sv_KD + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + mvq_tokens, mvq_token_lens = collate_custom_field(cuts_pre_mixed, "codebook_indexes", pad_value=-100) + + if self.at_KD: + at_targets = collate_custom_field( + cuts_pre_mixed, "beats_embedding", pad_value=-100 + ) # (N,C) + else: + audio_events = [c.supervisions[0].audio_event for c in cuts_pre_mixed] # the label indices are in CED format + at_targets, _ = str2multihot(audio_events) # (N, num_events) + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.codebook_indexes.array.storage_key != "dummy_whisper_codebook_indexes_1510" + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.beats_embedding.storage_key != "dummy_beats_embedding" \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/dataset2_batch_mixing.py b/egs/emilia/CLAP/spear/dataset2_batch_mixing.py new file mode 100644 index 0000000000..abd8ad72f3 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset2_batch_mixing.py @@ -0,0 +1,646 @@ +import math +import random +from threading import Lock +from typing import Callable, Dict, List, Optional, Union, Tuple + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse import Fbank, FbankConfig +from lhotse.cut import CutSet, MonoCut, Cut, MixedCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field, collate_matrices +from lhotse.utils import compute_num_frames, ifnone, LOG_EPSILON +from lhotse.workarounds import Hdf5MemoryIssueFix + +from lhotse.cut.set import mix + +class CodebookCache: + """ + Cache of 'bytes' objects with audio data. + It is used to cache the "command" type audio inputs. + + By default it is disabled, to enable call `set_caching_enabled(True)` + or `AudioCache.enable()`. + + The cache size is limited to max 100 elements and 500MB of audio. + + A global dict `__cache_dict` (static member variable of class AudioCache) + is holding the codebooks as np.array. + The key is the supervision ID, we avoid using cut.id because the cut IDs could be ruined by repeat + + Thread-safety is ensured by a threading.Lock guard. + """ + + __enabled: bool = False + + max_cache_memory: int = 500 * 1e6 # 500 MB + max_cache_elements: int = 10000 # number audio files + + __cache_dict: Dict[str, np.array] = {} + __lock: Lock = Lock() + + @classmethod + def enable(cls, enabled=True): + cls.__enabled = enabled + if not enabled: + cls.__clear_cache() + + @classmethod + def enabled(cls) -> bool: + return cls.__enabled + + @classmethod + def try_cache(cls, key: str) -> Optional[bytes]: + """ + Test if 'key' is in the chache. If yes return the bytes array, + otherwise return None. + """ + + if not cls.__enabled: + return None + + with cls.__lock: + if key in cls.__cache_dict: + return cls.__cache_dict[key] + else: + return None + + @classmethod + def add_to_cache(cls, key: str, value: np.array): + """ + Add the new (key,value) pair to cache. + Possibly free some elements before adding the new pair. + The oldest elements are removed first. + """ + + if not cls.__enabled: + return None + + if value.itemsize * value.size > cls.max_cache_memory: + return + + with cls.__lock: + # limit cache elements + while len(cls.__cache_dict) > cls.max_cache_elements: + # remove oldest elements from cache + # (dict pairs are sorted according to insertion order) + cls.__cache_dict.pop(next(iter(cls.__cache_dict))) + + # limit cache memory + while value.itemsize * value.size + CodebookCache.__cache_memory() > cls.max_cache_memory: + # remove oldest elements from cache + # (dict pairs are sorted according to insertion order) + cls.__cache_dict.pop(next(iter(cls.__cache_dict))) + + # store the new (key,value) pair + cls.__cache_dict[key] = value + + @classmethod + def __cache_memory(cls) -> int: + """ + Return size of CodebookCache values in bytes. + (internal, not to be called from outside) + """ + ans = 0 + for key, value in cls.__cache_dict.items(): + ans += value.itemsize * value.size + return ans + + @classmethod + def __clear_cache(cls) -> None: + """ + Clear the cache, remove the data. + """ + with cls.__lock: + cls.__cache_dict.clear() + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + if not isinstance(events, list): + events = [events] + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + target_frame_rate: int = 50, + at_KD: bool = False, + sv_KD: bool = False, + enable_cache: bool = True, + token_mixing: bool = False + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + :param enable_cache: Enables a cache for the codebook indexes + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + self.extractor = Fbank(FbankConfig(num_mel_bins=128)) + + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.target_frame_rate = target_frame_rate + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + self.enable_cache = enable_cache + if self.enable_cache: + CodebookCache.enable() + assert CodebookCache.enabled() + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + self.token_mixing = token_mixing + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + # validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + inputs, input_lens, mix_ratios = self.load_audio_and_compute_fbank(cuts) + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = fix_start(cuts_pre_mixed) + + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + target_frame_rate=self.target_frame_rate, + pad_value=-100, + ) + + # perform token mixing + if self.token_mixing: + mvq_tokens = self.mix_mvq_tokens(mvq_tokens, cuts, mix_ratios) + + if self.at_KD: + # at_targets = collate_custom_field( + # cuts_pre_mixed, "beats_embedding", pad_value=-100 + # ) # (N,C) + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [getattr(c.supervisions[0], "audio_event", "0") for c in cuts_pre_mixed] # the label indices are in CED format + # at_targets, _ = str2multihot(audio_events) # (N, num_events) + at_targets = None + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + def load_audio_and_compute_fbank(self, cuts: CutSet): + audios = [] + mix_ratios = [] + for cut in cuts: + if isinstance(cut, MixedCut): + audio, mix_ratio = _load_mixed_cut_single(cut) + else: + audio = cut.load_audio() + mix_ratio = 0.0 + audios.append(audio) + mix_ratios.append(mix_ratio) + + inputs, input_lens = compute_feature(audios, cuts, self.extractor) + + return inputs, input_lens, mix_ratios + + def load_codebook_indexes(self, cuts: CutSet, field_name: str = "codebook_indexes"): + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = fix_start(cuts_pre_mixed) + + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + field_name, + dummy=self.dummy_codebook_indexes, + temporal_array=True, + target_frame_rate=self.target_frame_rate, + pad_value=-100, + ) + if self.mix_codebook_indexes: + for i,c in enumerate(cuts): + if not isinstance(c, MixedCut): + continue + orig_track, mix_track = c.tracks # get the two tracks + mixed_in_cb = mix_track.load_custom(field_name) # should be only within the mix region + offset = compute_num_frames(mix_track.start, 1 / self.target_frame_rate) + return mvq_tokens, mvq_token_lens + + def mix_mvq_tokens( + self, + mvq_tokens: torch.Tensor, + cuts: CutSet, + mix_ratios: List[float], + field_name: str = "codebook_indexes" + ): + # Randomly replace a proportion of the original codebook indexes + # with the codebook indexes from the mixed cut. The proportion is determined + # by the gain of the mixed audio + for i,c in enumerate(cuts): + if not isinstance(c, MixedCut): + continue + orig_track, mix_track = c.tracks # get the two tracks + + # compute the starting mixing frame + offset = int(mix_track.offset * self.target_frame_rate) + mixed_in_cb = torch.from_numpy(mix_track.cut.load_custom(field_name)) # should be only within the mix region + mix_length = mixed_in_cb.shape[0] + if mix_length + offset >= mvq_tokens.size(1): + mix_length = mvq_tokens.size(1) - offset + mixed_in_cb = mixed_in_cb[:, :mix_length] + cur_cb_slice = mvq_tokens[i, offset:offset + mix_length, :] + p = gain2prob(mix_ratios[i]) + mixed_cb = _mix_tokens_single(cur_cb_slice, mixed_in_cb, p) + mvq_tokens[i, offset:offset + mix_length] = mixed_cb + return mvq_tokens + +def gain2prob(gain: float, alpha: float=2.0): + # x**alpha/(1+x**alpha), x is gain, alpha is empirically tuned + return gain ** alpha / (1 + gain**alpha) + +def audio_energy(audio: np.ndarray): + return float(np.average(audio**2)) + +def _load_mixed_cut_single(cut: MixedCut) -> Tuple[np.ndarray, float]: + # we only deal with the first channel + sample_rate = cut.sampling_rate + orig_cut = cut.tracks[0].cut + mix_in_cut = cut.tracks[1].cut + snr = cut.tracks[1].snr + + # compute some numbers + mix_offset = cut.tracks[1].offset + mix_offset_frames = int(sample_rate * mix_offset) # compute the frame shift for mixing + + # we take the first channel + orig_audio = orig_cut.load_audio() + mix_in_audio = mix_in_cut.load_audio() + mix_in_frames = mix_in_audio.shape[1] + + energy_orig = audio_energy(orig_audio[0, mix_offset_frames:mix_offset_frames + mix_in_frames]) + target_energy = energy_orig * (10.0 ** (-snr / 10)) + energy_mix_in = audio_energy(mix_in_audio) + gain = math.sqrt(target_energy / (energy_mix_in + 1e-8)) + + if mix_in_frames + mix_offset_frames <= orig_audio.shape[1]: + orig_audio[0, mix_offset_frames:mix_offset_frames + mix_in_frames] += gain * mix_in_audio[0] + else: + mix_in_frames = orig_audio.shape[1] - mix_offset_frames + orig_audio[0, mix_offset_frames:mix_offset_frames+mix_in_frames] += gain * mix_in_audio[0, :mix_in_frames] + + return orig_audio, gain + +def mix_audio_with_offset( + reference_cut: Cut, + mixed_in_cut: Cut, + snr: float = 10.0, + drop_mixed_in_supervision: bool = True +): + if drop_mixed_in_supervision: + mixed_in_cut = mixed_in_cut.drop_supervisions() + ref_duration = reference_cut.duration + mixed_in_duration = mixed_in_cut.duration + + mix_duration = random.uniform(0, ref_duration / 2) + + # randomly truncate the mixed_in_cut to mix_duration if longer + if mixed_in_duration > mix_duration: + diff = mixed_in_duration - mix_duration + truncate_start = random.uniform(0, diff) + mixed_in_cut = mixed_in_cut.truncate(offset=truncate_start, duration=mix_duration) + + actual_mix_duration = min(mixed_in_cut.duration, mix_duration) + offset = random.uniform(0, ref_duration - actual_mix_duration - 0.05) # a tolerance of 0.05 for safety + mixed_cut = mix( + reference_cut=reference_cut, + mixed_in_cut=mixed_in_cut, + offset=offset, + snr=snr, + preserve_id="left", + ) + + return mixed_cut + +def _mix_tokens_single(A: torch.Tensor, B: torch.Tensor, p: float) -> torch.Tensor: + """ + 从 A 中随机选出 p% 的位置,用 B 中对应位置的值替换。 + + 参数: + A (Tensor): 原始张量,形状为 (T, C) + B (Tensor): 替换来源张量,形状必须与 A 相同 + p (float): 替换比例,范围为 0~1 + + 返回: + Tensor: 替换后的新张量 + """ + assert A.shape == B.shape, "A and B must have the same shape" + assert 0 <= p <= 1, "p must be between 0 and 1" + + # 创建一个与 A 相同形状的 mask,表示哪些位置需要替换 + mask = torch.rand_like(A, dtype=torch.float32) < p + + # 创建新的张量:如果 mask 为 True,就用 B 的值,否则用 A 的值 + return torch.where(mask, B, A) + +def compute_feature(audios, cuts, extractor): + # compute features given the audios + # cuts is only for metadata reading + features_single = [] + for idx, (audio, cut) in enumerate(zip(audios, cuts)): + try: + features = extractor.extract(audio, cuts[idx].sampling_rate) + except: + print( + f"Error while extracting the features for cut with ID {cut.id} -- details:\n{cut}" + ) + raise + features_single.append(torch.from_numpy(features)) + + features_batch = collate_matrices(features_single, padding_value=LOG_EPSILON) + + feature_lens = torch.tensor( + [f.shape[0] for f in features_single], dtype=torch.int64 + ) + + out = (features_batch, feature_lens) + return out + +def fix_start(cuts): + # make the start of codebook indexes the same as the cut + new_cuts = [] + for cut in cuts: + if cut.has_custom("codebook_indexes") and (not isinstance(cut.codebook_indexes, dict)): + cut.codebook_indexes.start = cut.start + if cut.has_custom("firered_codebook_indexes") and (not isinstance(cut.firered_codebook_indexes, dict)): + cut.firered_codebook_indexes.start = cut.start + new_cuts.append(cut) + return new_cuts + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.codebook_indexes.array.storage_key != "dummy_whisper_codebook_indexes_1510" + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.beats_embedding.storage_key != "dummy_beats_embedding" + +def load_codebook_indexes(c): + info = c.codebook_indexes + cached_cb = CodebookCache.try_cache(c.supervisions[0].id) # we use supervision ID rather than cut id because cuts.repeat() ruins the cut id + if cached_cb is not None: + return cached_cb + else: + if isinstance(info, dict): + filename = info["path"] + with open(filename, "rb") as f: + cb_indexes = np.load(f) + # return np.load(filename, mmap_mode="r") + else: + cb_indexes = c.load_custom("codebook_indexes") + + CodebookCache.add_to_cache(c.supervisions[0].id, cb_indexes) + return cb_indexes + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: torch.Tensor = None, + temporal_array: bool = True, + target_frame_rate: int = 50, + pad_value=None, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * target_frame_rate) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(load_codebook_indexes(c)) if c.has_custom(field) else dummy for c in cuts # load the numpy codebook indexes + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + +def _test_mix(): + from lhotse import load_manifest_lazy + manifest = "data/fbank/librispeech_cuts_dev-other.jsonl.gz" + cuts = load_manifest_lazy(manifest).drop_features() + reference_cut = cuts[0] + noise_cuts = [cuts[4], cuts[2]] + + for noise_cut in noise_cuts: + mixed_cut = mix_audio_with_offset(reference_cut=reference_cut, mixed_in_cut=noise_cut, snr=5) + print(mixed_cut) + +if __name__=="__main__": + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + _test_mix() + + # enable the cache + CodebookCache.enable() + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("data/vq_hubert_large_layer_21_normalize_1_cb_16/librispeech_cuts_dev-clean.jsonl.gz").subset(first=500).repeat(2) + # cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + # cuts = load_manifest("debug.jsonl.gz") + + # gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + import time + start = time.time() + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + print(f"Cache: {CodebookCache.enabled()}, Time elapsed: {time.time() - start}") + # print(gt_mvq_tokens) + + # gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + # beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + # print(beats_embed) + diff --git a/egs/emilia/CLAP/spear/dataset2_dummy.py b/egs/emilia/CLAP/spear/dataset2_dummy.py new file mode 100644 index 0000000000..91700fda77 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset2_dummy.py @@ -0,0 +1,316 @@ +import math +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + at_KD: bool = False, + sv_KD: bool = False + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = fix_start(cuts_pre_mixed) + #mvq_tokens, mvq_token_lens = collate_custom_field(cuts_pre_mixed, "codebook_indexes", pad_value=-100) + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + pad_value=-100, + ) + + if self.at_KD: + # at_targets = collate_custom_field( + # cuts_pre_mixed, "beats_embedding", pad_value=-100 + # ) # (N,C) + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [c.supervisions[0].audio_event for c in cuts_pre_mixed] # the label indices are in CED format + at_targets, _ = str2multihot(audio_events) # (N, num_events) + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + +def fix_start(cuts): + # make the start of codebook indexes the same as the cut + new_cuts = [] + for cut in cuts: + if cut.has_custom("codebook_indexes"): + cut.codebook_indexes.start = cut.start + if cut.has_custom("firered_codebook_indexes"): + cut.firered_codebook_indexes.start = cut.start + new_cuts.append(cut) + return new_cuts + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.codebook_indexes.array.storage_key != "dummy_whisper_codebook_indexes_1510" + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.beats_embedding.storage_key != "dummy_beats_embedding" + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: np.array = None, + temporal_array: bool = True, + pad_value=None, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * 50) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + + +if __name__=="__main__": + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("debug.jsonl.gz") + cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + cuts = load_manifest("debug.jsonl.gz") + + gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + print(gt_mvq_tokens) + + gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + print(beats_embed) + diff --git a/egs/emilia/CLAP/spear/dataset2_npy.py b/egs/emilia/CLAP/spear/dataset2_npy.py new file mode 100644 index 0000000000..fd71f84673 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset2_npy.py @@ -0,0 +1,453 @@ +import math +from threading import Lock +from typing import Callable, Dict, List, Optional, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +class CodebookCache: + """ + Cache of 'bytes' objects with audio data. + It is used to cache the "command" type audio inputs. + + By default it is disabled, to enable call `set_caching_enabled(True)` + or `AudioCache.enable()`. + + The cache size is limited to max 100 elements and 500MB of audio. + + A global dict `__cache_dict` (static member variable of class AudioCache) + is holding the codebooks as np.array. + The key is the supervision ID, we avoid using cut.id because the cut IDs could be ruined by repeat + + Thread-safety is ensured by a threading.Lock guard. + """ + + __enabled: bool = False + + max_cache_memory: int = 500 * 1e6 # 500 MB + max_cache_elements: int = 5000 # number audio files + + __cache_dict: Dict[str, np.array] = {} + __lock: Lock = Lock() + + @classmethod + def enable(cls, enabled=True): + cls.__enabled = enabled + if not enabled: + cls.__clear_cache() + + @classmethod + def enabled(cls) -> bool: + return cls.__enabled + + @classmethod + def try_cache(cls, key: str) -> Optional[bytes]: + """ + Test if 'key' is in the chache. If yes return the bytes array, + otherwise return None. + """ + + if not cls.__enabled: + return None + + with cls.__lock: + if key in cls.__cache_dict: + return cls.__cache_dict[key] + else: + return None + + @classmethod + def add_to_cache(cls, key: str, value: np.array): + """ + Add the new (key,value) pair to cache. + Possibly free some elements before adding the new pair. + The oldest elements are removed first. + """ + + if not cls.__enabled: + return None + + if value.itemsize * value.size > cls.max_cache_memory: + return + + with cls.__lock: + # limit cache elements + while len(cls.__cache_dict) > cls.max_cache_elements: + # remove oldest elements from cache + # (dict pairs are sorted according to insertion order) + cls.__cache_dict.pop(next(iter(cls.__cache_dict))) + + # limit cache memory + while value.itemsize * value.size + CodebookCache.__cache_memory() > cls.max_cache_memory: + # remove oldest elements from cache + # (dict pairs are sorted according to insertion order) + cls.__cache_dict.pop(next(iter(cls.__cache_dict))) + + # store the new (key,value) pair + cls.__cache_dict[key] = value + + @classmethod + def __cache_memory(cls) -> int: + """ + Return size of CodebookCache values in bytes. + (internal, not to be called from outside) + """ + ans = 0 + for key, value in cls.__cache_dict.items(): + ans += value.itemsize * value.size + return ans + + @classmethod + def __clear_cache(cls) -> None: + """ + Clear the cache, remove the data. + """ + with cls.__lock: + cls.__cache_dict.clear() + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + if not isinstance(events, list): + events = [events] + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + target_frame_rate: int = 50, + at_KD: bool = False, + sv_KD: bool = False, + enable_cache: bool = True, + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + :param enable_cache: Enables a cache for the codebook indexes + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.target_frame_rate = target_frame_rate + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + self.enable_cache = enable_cache + if self.enable_cache: + CodebookCache.enable() + assert CodebookCache.enabled() + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + # validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + # cuts_pre_mixed = fix_start(cuts_pre_mixed) + + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + target_frame_rate=self.target_frame_rate, + pad_value=-100, + ) + + if self.at_KD: + # at_targets = collate_custom_field( + # cuts_pre_mixed, "beats_embedding", pad_value=-100 + # ) # (N,C) + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [getattr(c.supervisions[0], "audio_event", "0") for c in cuts_pre_mixed] # the label indices are in CED format + at_targets, _ = str2multihot(audio_events) # (N, num_events) + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + +def fix_start(cuts): + # make the start of codebook indexes the same as the cut + new_cuts = [] + for cut in cuts: + if cut.has_custom("codebook_indexes"): + cut.codebook_indexes.start = cut.start + if cut.has_custom("firered_codebook_indexes"): + cut.firered_codebook_indexes.start = cut.start + new_cuts.append(cut) + return new_cuts + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.codebook_indexes.array.storage_key != "dummy_whisper_codebook_indexes_1510" + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.beats_embedding.storage_key != "dummy_beats_embedding" + +def load_codebook_indexes(c): + info = c.codebook_indexes + cached_cb = CodebookCache.try_cache(c.supervisions[0].id) # we use supervision ID rather than cut id because cuts.repeat() ruins the cut id + if cached_cb is not None: + return cached_cb + else: + if isinstance(info, dict): + filename = info["path"] + with open(filename, "rb") as f: + cb_indexes = np.load(f) + # return np.load(filename, mmap_mode="r") + else: + cb_indexes = c.load_custom("codebook_indexes") + + CodebookCache.add_to_cache(c.supervisions[0].id, cb_indexes) + return cb_indexes + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: torch.Tensor = None, + temporal_array: bool = True, + target_frame_rate: int = 50, + pad_value=None, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * target_frame_rate) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(load_codebook_indexes(c)) if c.has_custom(field) else dummy for c in cuts # load the numpy codebook indexes + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + + +if __name__=="__main__": + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + # enable the cache + CodebookCache.enable() + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("data/vq_hubert_large_layer_21_normalize_1_cb_16/librispeech_cuts_dev-clean.jsonl.gz").subset(first=500).repeat(2) + # cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + # cuts = load_manifest("debug.jsonl.gz") + + # gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + import time + start = time.time() + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + print(f"Cache: {CodebookCache.enabled()}, Time elapsed: {time.time() - start}") + # print(gt_mvq_tokens) + + # gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + # beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + # print(beats_embed) + diff --git a/egs/emilia/CLAP/spear/dataset2_npy_cache.py b/egs/emilia/CLAP/spear/dataset2_npy_cache.py new file mode 100644 index 0000000000..9c9132db74 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset2_npy_cache.py @@ -0,0 +1,459 @@ +import math +import random +from threading import Lock +from typing import Callable, Dict, List, Optional, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut, Cut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +from lhotse.cut.set import mix + +class CodebookCache: + """ + Cache of 'bytes' objects with audio data. + It is used to cache the "command" type audio inputs. + + By default it is disabled, to enable call `set_caching_enabled(True)` + or `AudioCache.enable()`. + + The cache size is limited to max 100 elements and 500MB of audio. + + A global dict `__cache_dict` (static member variable of class AudioCache) + is holding the codebooks as np.array. + The key is the supervision ID, we avoid using cut.id because the cut IDs could be ruined by repeat + + Thread-safety is ensured by a threading.Lock guard. + """ + + __enabled: bool = False + + max_cache_memory: int = 500 * 1e6 # 500 MB + max_cache_elements: int = 10000 # number audio files + + __cache_dict: Dict[str, np.array] = {} + __lock: Lock = Lock() + + @classmethod + def enable(cls, enabled=True): + cls.__enabled = enabled + if not enabled: + cls.__clear_cache() + + @classmethod + def enabled(cls) -> bool: + return cls.__enabled + + @classmethod + def try_cache(cls, key: str) -> Optional[bytes]: + """ + Test if 'key' is in the chache. If yes return the bytes array, + otherwise return None. + """ + + if not cls.__enabled: + return None + + with cls.__lock: + if key in cls.__cache_dict: + return cls.__cache_dict[key] + else: + return None + + @classmethod + def add_to_cache(cls, key: str, value: np.array): + """ + Add the new (key,value) pair to cache. + Possibly free some elements before adding the new pair. + The oldest elements are removed first. + """ + + if not cls.__enabled: + return None + + if value.itemsize * value.size > cls.max_cache_memory: + return + + with cls.__lock: + # limit cache elements + while len(cls.__cache_dict) > cls.max_cache_elements: + # remove oldest elements from cache + # (dict pairs are sorted according to insertion order) + cls.__cache_dict.pop(next(iter(cls.__cache_dict))) + + # limit cache memory + while value.itemsize * value.size + CodebookCache.__cache_memory() > cls.max_cache_memory: + # remove oldest elements from cache + # (dict pairs are sorted according to insertion order) + cls.__cache_dict.pop(next(iter(cls.__cache_dict))) + + # store the new (key,value) pair + cls.__cache_dict[key] = value + + @classmethod + def __cache_memory(cls) -> int: + """ + Return size of CodebookCache values in bytes. + (internal, not to be called from outside) + """ + ans = 0 + for key, value in cls.__cache_dict.items(): + ans += value.itemsize * value.size + return ans + + @classmethod + def __clear_cache(cls) -> None: + """ + Clear the cache, remove the data. + """ + with cls.__lock: + cls.__cache_dict.clear() + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + if not isinstance(events, list): + events = [events] + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + target_frame_rate: int = 50, + at_KD: bool = False, + sv_KD: bool = False, + enable_cache: bool = True, + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + :param enable_cache: Enables a cache for the codebook indexes + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.target_frame_rate = target_frame_rate + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + self.enable_cache = enable_cache + if self.enable_cache: + CodebookCache.enable() + assert CodebookCache.enabled() + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + # validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = fix_start(cuts_pre_mixed) + + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + target_frame_rate=self.target_frame_rate, + pad_value=-100, + ) + + if self.at_KD: + # at_targets = collate_custom_field( + # cuts_pre_mixed, "beats_embedding", pad_value=-100 + # ) # (N,C) + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [getattr(c.supervisions[0], "audio_event", "0") for c in cuts_pre_mixed] # the label indices are in CED format + # at_targets, _ = str2multihot(audio_events) # (N, num_events) + at_targets = None + + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + +def fix_start(cuts): + # make the start of codebook indexes the same as the cut + new_cuts = [] + for cut in cuts: + if cut.has_custom("codebook_indexes") and (not isinstance(cut.codebook_indexes, dict)): + cut.codebook_indexes.start = cut.start + if cut.has_custom("firered_codebook_indexes") and (not isinstance(cut.firered_codebook_indexes, dict)): + cut.firered_codebook_indexes.start = cut.start + new_cuts.append(cut) + return new_cuts + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.codebook_indexes.array.storage_key != "dummy_whisper_codebook_indexes_1510" + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.beats_embedding.storage_key != "dummy_beats_embedding" + +def load_codebook_indexes(c): + info = c.codebook_indexes + cached_cb = CodebookCache.try_cache(c.supervisions[0].id) # we use supervision ID rather than cut id because cuts.repeat() ruins the cut id + if cached_cb is not None: + return cached_cb + else: + if isinstance(info, dict): + filename = info["path"] + with open(filename, "rb") as f: + cb_indexes = np.load(f) + # return np.load(filename, mmap_mode="r") + else: + cb_indexes = c.load_custom("codebook_indexes") + + CodebookCache.add_to_cache(c.supervisions[0].id, cb_indexes) + return cb_indexes + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: torch.Tensor = None, + temporal_array: bool = True, + target_frame_rate: int = 50, + pad_value=None, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * target_frame_rate) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(load_codebook_indexes(c)) if c.has_custom(field) else dummy for c in cuts # load the numpy codebook indexes + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + + + +if __name__=="__main__": + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + # enable the cache + CodebookCache.enable() + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("data/vq_hubert_large_layer_21_normalize_1_cb_16/librispeech_cuts_dev-clean.jsonl.gz").subset(first=500).repeat(2) + # cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + # cuts = load_manifest("debug.jsonl.gz") + + # gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + import time + start = time.time() + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + print(f"Cache: {CodebookCache.enabled()}, Time elapsed: {time.time() - start}") + # print(gt_mvq_tokens) + + # gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + # beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + # print(beats_embed) + diff --git a/egs/emilia/CLAP/spear/dataset_at.py b/egs/emilia/CLAP/spear/dataset_at.py new file mode 100644 index 0000000000..4fa0a2a4f7 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset_at.py @@ -0,0 +1,314 @@ +from typing import Callable, Dict, List, Union +import random + +import numpy as np + +from lhotse import CutSet, load_manifest, load_manifest_lazy +from lhotse import Fbank, FbankConfig +from lhotse.dataset import CutMix +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures, OnTheFlyFeatures +from lhotse.dataset.collation import read_audio_from_cuts, collate_matrices +from lhotse.cut import MonoCut +from lhotse.utils import LOG_EPSILON, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +import torch +import torch.utils +from torch.utils.data.dataloader import DataLoader, default_collate + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + if not isinstance(events, list): + events = [events] + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + +class MultiTaskDataset(torch.utils.data.Dataset): + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + mixup_cuts: CutSet = None, + mixup_prob: float = 0.5, + mvq_KD: bool = False, + at_KD: bool = False, + sv_KD: bool = False + ): + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + self.extractor = Fbank(FbankConfig(num_mel_bins=128)) + + self.mvq_KD = mvq_KD + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.mixup_cuts = mixup_cuts + self.mixup_prob = mixup_prob + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + audios, cuts, mix_labels = self.read_and_mix_audio(cuts, p=self.mixup_prob) + + inputs, input_lens = compute_feature(audios, cuts, self.extractor) + + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + # cuts_pre_mixed = fix_start(cuts_pre_mixed) + + if self.mvq_KD: + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + pad_value=-100, + ) + else: + mvq_tokens = None + mvq_token_lens = None + + if self.at_KD: + # at_targets = collate_custom_field( + # cuts_pre_mixed, "beats_embedding", pad_value=-100 + # ) # (N,C) + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + at_targets = mix_labels + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": mvq_tokens, + "cb_indexes_len": mvq_token_lens, + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + + def read_and_mix_audio(self, cuts: CutSet, p: float=0.5): + audios = [] + out_cuts = [] + labels = [] + for cut in cuts: + # mix the audio + if random.random() < self.mixup_prob and self.mixup_cuts is not None: + mix_cut = self.mixup_cuts.sample(n_cuts=1) + audio, label = _read_and_mix_audio_single(cut, mix_cut) + else: + audio = cut.load_audio() + label, _ = str2multihot(cut.supervisions[0].audio_event) + audios.append(audio) + out_cuts.append(cut) + labels.append(label) + + labels = torch.cat(labels, dim=0) # (B,num_classes) + + return audios, CutSet.from_cuts(out_cuts), labels + +def _read_and_mix_audio_single(cut, mix_cut): + mix_lambda = np.random.beta(10,10) + audio1 = cut.load_audio() + audio2 = mix_cut.load_audio() + if audio1.shape[1] > audio2.shape[1]: + diff = audio1.shape[1] - audio2.shape[1] + padding = np.zeros((1, diff), dtype=np.float32) + audio2 = np.concatenate((audio2, padding), axis=1) + else: + audio2 = audio2[:, :audio1.shape[1]] + + # mix the audio waveform + mix_audio = audio1 * mix_lambda + audio2 * (1 - mix_lambda) + + # mix the label + label1, _ = str2multihot(cut.supervisions[0].audio_event) + label2, _ = str2multihot(mix_cut.supervisions[0].audio_event) + mix_label = label1 * mix_lambda + label2 * (1 - mix_lambda) + + return mix_audio, mix_label + +def compute_feature(audios, cuts, extractor): + # compute features given the audios + # cuts is only for metadata reading + features_single = [] + for idx, (audio, cut) in enumerate(zip(audios, cuts)): + try: + features = extractor.extract(audio, cuts[idx].sampling_rate) + except: + print( + f"Error while extracting the features for cut with ID {cut.id} -- details:\n{cut}" + ) + raise + features_single.append(torch.from_numpy(features)) + + features_batch = collate_matrices(features_single, padding_value=LOG_EPSILON) + + feature_lens = torch.tensor( + [f.shape[0] for f in features_single], dtype=torch.int64 + ) + + out = (features_batch, feature_lens) + return out + + +def load_codebook_indexes(c): + info = c.codebook_indexes + if isinstance(info, dict): + filename = info["path"] + return np.load(filename) + else: + return c.load_custom("codebook_indexes") + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: np.array = None, + temporal_array: bool = True, + pad_value=None, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * 50) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(load_codebook_indexes(c)) if c.has_custom(field) else dummy for c in cuts # load the numpy codebook indexes + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + +def test_dataset(): + mixup_cuts = load_manifest("data/fbank_as_ced_mAP50/audioset_cuts_balanced.jsonl.gz").drop_features() + dataset = MultiTaskDataset( + return_cuts=True, + mixup_cuts=mixup_cuts, + mixup_prob=0.5, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + ) + + cuts = load_manifest("data/fbank_as_ced_mAP50/audioset_cuts_balanced.jsonl.gz").drop_features() + cuts = cuts.subset(first=5) + batch = dataset[cuts] + print(batch) + + +def test_mix(): + musan_cuts = load_manifest("data/fbank/musan_cuts.jsonl.gz").drop_features() + noise_cuts = CutSet.from_cuts([musan_cuts[0]]) + + transform = CutMix(cuts=noise_cuts, p=1.0, snr=0, preserve_id=True) + + audio_cuts = load_manifest_lazy("data/fbank_as_ced_mAP50/audioset_cuts_balanced.jsonl.gz").drop_features() + cuts = audio_cuts.subset(first=10) + + mixed_cuts = transform(cuts) + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + noise_audio = noise_cuts[0].load_audio() + + extractor = Fbank(FbankConfig(num_mel_bins=128)) + + for mixed_cut, pre_mixed_cut in zip(mixed_cuts, cuts_pre_mixed): + + mixed_audio = mixed_cut.load_audio() + orig_audio = pre_mixed_cut.load_audio() + audio_diff = mixed_audio - orig_audio + print(mixed_audio) + +def test_read_audio(): + audio_cuts = load_manifest_lazy("data/fbank_as_ced_mAP50/audioset_cuts_balanced.jsonl.gz").drop_features() + cuts = audio_cuts.subset(first=10) + + audios, cuts = read_audio_from_cuts(cuts) + + print(audios) + print(cuts) + + +if __name__=="__main__": + test_dataset() \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/dataset_multi_speech_mvq.py b/egs/emilia/CLAP/spear/dataset_multi_speech_mvq.py new file mode 100644 index 0000000000..310c3a7fae --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset_multi_speech_mvq.py @@ -0,0 +1,336 @@ +import math +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + at_KD: bool = False, + sv_KD: bool = False + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + # cuts_pre_mixed = fix_start(cuts_pre_mixed) + #mvq_tokens, mvq_token_lens = collate_custom_field(cuts_pre_mixed, "codebook_indexes", pad_value=-100) + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + pad_value=-100, + frame_rate=50, + ) + mvq_tokens2, mvq_token2_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes2", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + pad_value=-100, + frame_rate=25, + ) + + if self.at_KD: + # at_targets = collate_custom_field( + # cuts_pre_mixed, "beats_embedding", pad_value=-100 + # ) # (N,C) + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [getattr(c.supervisions[0], "audio_event", "0") for c in cuts_pre_mixed] # the label indices are in CED format + at_targets, _ = str2multihot(audio_events) # (N, num_events) + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": [mvq_tokens, mvq_tokens2], + "cb_indexes_len": [mvq_token_lens, mvq_token2_lens], + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + +def fix_start(cuts): + # make the start of codebook indexes the same as the cut + new_cuts = [] + for cut in cuts: + if cut.has_custom("codebook_indexes"): + cut.codebook_indexes.start = cut.start + if cut.has_custom("firered_codebook_indexes"): + cut.firered_codebook_indexes.start = cut.start + new_cuts.append(cut) + return new_cuts + + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.has_custom("codebook_indexes") or cut.has_custom("firered_codebook_indexes") + elif cut.task_id == 2: + # audio cuts, should have audio logits + assert cut.has_custom("beats_embedding") + +def load_codebook_indexes(c, field: str = "codebook_indexes"): + info = getattr(c, field) + if isinstance(info, dict): + filename = info["path"] + return np.load(filename) + else: + return c.load_custom(field) + + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: np.array = None, + temporal_array: bool = True, + pad_value=None, + frame_rate: int = 50, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * frame_rate) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(load_codebook_indexes(c, field)) if c.has_custom(field) else dummy for c in cuts + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + + +if __name__=="__main__": + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("debug.jsonl.gz") + cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + cuts = load_manifest("debug.jsonl.gz") + + gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + import pdb; pdb.set_trace() + print(gt_mvq_tokens) + + gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + print(beats_embed) + diff --git a/egs/emilia/CLAP/spear/dataset_speech_audio_mvq.py b/egs/emilia/CLAP/spear/dataset_speech_audio_mvq.py new file mode 100644 index 0000000000..f5dd099e94 --- /dev/null +++ b/egs/emilia/CLAP/spear/dataset_speech_audio_mvq.py @@ -0,0 +1,404 @@ +import math +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate +import numpy as np + +from lhotse import validate +from lhotse.cut import CutSet, MonoCut +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.dataset.collation import collate_custom_field +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # generate multi-hot class labels + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[l] for l in label] + out[i, label] = 1 + + return out, labels + + +class MultiTaskKDDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the multi task speech and audio processing. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + at_KD: bool = False, + sv_KD: bool = False, + speech_target_frame_rate: int = 50, + num_cb_speech: int = 16, + audio_target_frame_rate: int = 25, + num_cb_audio: int = 16, + batch_duration_threshold: int = 2000, + ): + """ + IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + :param speech_target_frame_rate: The label frame rate for speech data. + :param audio_target_frame_rate: The label frame rate for audio data + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + self.at_KD = at_KD + self.sv_KD = sv_KD + + self.speech_target_frame_rate = speech_target_frame_rate + self.audio_target_frame_rate = audio_target_frame_rate + self.dummy_codebook_indexes = torch.ones(1510, num_cb_speech) * (-100) + self.dummy_audio_codebook_indexes = torch.ones(1510, num_cb_audio) * (-100) + self.dummy_audio_logits = torch.ones(527) * 0.5 + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + # a strict constraint on the total duration (after padding), if + # a batch exceeds this limit, will remove the last (short) few cuts + # until the total duration is under this limit + # This is especially helpful for Zipsampler, as one sampler could yield + # a lot of cuts with short length, while the other with cuts that are very + # long, making the total batch extremly large! + self.batch_duration_threshold = batch_duration_threshold + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_multi_kd(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + cuts = filter_cuts_by_duration(cuts, batch_duration_threshold=self.batch_duration_threshold) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + assert inputs.shape[0] == len(cuts) + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + # MVQ tokens + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = fix_start(cuts_pre_mixed) + assert len(cuts_pre_mixed) == len(cuts) + + # load speech indexes + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "codebook_indexes", + dummy=self.dummy_codebook_indexes, + temporal_array=True, + pad_value=-100, + frame_rate=self.speech_target_frame_rate, + ) + + # load audio cb indexes + audio_mvq_tokens, audio_mvq_token_lens = _collate_custom_field( + cuts_pre_mixed, + "audio_codebook_indexes", + dummy=self.dummy_audio_codebook_indexes, + temporal_array=True, + pad_value=-100, + frame_rate=self.audio_target_frame_rate, + ) + + if self.at_KD: + at_targets = _collate_custom_field( + cuts_pre_mixed, "beats_embedding", dummy=self.dummy_audio_logits, temporal_array=False + ) # (N,C) + else: + audio_events = [getattr(c.supervisions[0], "audio_event", "0") for c in cuts_pre_mixed] # the label indices are in CED format + at_targets, _ = str2multihot(audio_events) # (N, num_events) + + sv_targets = None + + # task ids + task_ids = [c.task_id for c in cuts_pre_mixed] + task_ids = torch.tensor(task_ids) + + dummy_text = "This is dummy text." + + batch = { + "inputs": inputs, + "cb_indexes": [mvq_tokens, audio_mvq_tokens], + "cb_indexes_len": [mvq_token_lens, audio_mvq_token_lens], + "supervisions": default_collate( + [ + { + "text": supervision.text if supervision.text is not None else dummy_text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + "task_ids": task_ids, + "at_targets": at_targets, + "sv_targets": sv_targets, + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch + +def filter_cuts_by_duration(cuts: CutSet, batch_duration_threshold: int = 2000) -> CutSet: + # the cuts are sorted by duration in decending order + max_duration = cuts[0].duration + num_cuts = len(cuts) + max_cuts = int(batch_duration_threshold // max_duration) + if max_cuts < num_cuts: + return cuts.subset(first=max_cuts) + else: + return cuts + +def fix_start(cuts): + # make the start of codebook indexes the same as the cut + new_cuts = [] + for cut in cuts: + if cut.has_custom("codebook_indexes"): + if not isinstance(cut.codebook_indexes, dict): + cut.codebook_indexes.start = cut.start + if cut.has_custom("audio_codebook_indexes"): + if not isinstance(cut.audio_codebook_indexes, dict): + cut.audio_codebook_indexes.start = cut.start + new_cuts.append(cut) + return new_cuts + + +def validate_multi_kd(cuts: CutSet) -> None: + for cut in cuts: + # assert cut.has_features, cut + assert cut.has_custom("task_id") + if cut.task_id == 1: + # speech cuts, should have codebook indexes + assert cut.has_custom("codebook_indexes") + elif cut.task_id == 2: + assert cut.has_custom("audio_codebook_indexes") + # assert cut.has_custom("beats_embedding") + +def load_codebook_indexes(c, field: str = "codebook_indexes"): + info = getattr(c, field) + if isinstance(info, dict): + filename = info["path"] + return np.load(filename) + else: + return c.load_custom(field) + +def _collate_custom_field( + cuts: CutSet, + field: str, + dummy: np.array = None, + temporal_array: bool = True, + pad_value=None, + frame_rate: int = 50, +): + + # by default, we assert the frame_shift is 0.02 + if temporal_array: + max_frames = [int(c.duration * frame_rate) for c in cuts] + + temporal_dim = 0 + pad_value = -100 + arrs = [ + torch.from_numpy(load_codebook_indexes(c, field)) if c.has_custom(field) else dummy for c in cuts + ] + for i, arr in enumerate(arrs): + arrs[i] = arr[:max_frames[i],:] + + arr_lens = torch.tensor( + [a.shape[temporal_dim] for a in arrs], dtype=torch.int32 + ) + largest_arr = max(arrs, key=torch.numel) + maxlen = largest_arr.shape[temporal_dim] + collated_shape = (len(arrs), *largest_arr.shape) + dtype = largest_arr.dtype + if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)): + dtype = torch.int64 + tensors = pad_value * torch.ones(collated_shape, dtype=dtype) + for aidx, a in enumerate(arrs): + alen = a.shape[temporal_dim] + # Construct an index expression such as tensors[:, :alen, :, :] programmatically; + # All indices are set to ':', besides temporal dim which is determined on pad_direction. + + temporal_slice = slice(0, alen) + indices = (aidx,) + tuple( + temporal_slice if i == temporal_dim else slice(None, None, None) + for i in range(len(a.shape)) + ) + tensors[indices] = a + + return tensors, arr_lens + else: + all_arrays = [torch.from_numpy(c.load_custom(field)) if c.has_custom(field) else dummy for c in cuts] + return torch.stack(all_arrays) + +def _test_filter_cuts(): + batch = torch.load("zipformer_audio_encoder/exp-96M-zipformer-lh-large-giga-xl-emo-1-as-full-music4all-w2v2-mask-p-0.5-len-10-channel-mask-p-0.25-len-15-multi-mvq-hubert-large-cb16-1.0-dasheng-as-cb8-0.2-shar-md600/batch-bdd640fb-0667-1ad1-1c80-317fa3b1799d.pt") + cuts = batch["supervisions"]["cut"] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts = CutSet.from_cuts(cuts_pre_mixed) + + print(f"Length before filtering: {len(cuts)}") + filter_cuts = filter_cuts_by_duration(cuts, 2000) + print(f"Length after filtering: {len(filter_cuts)}") + assert type(cuts) == type(filter_cuts) + +def _test(): + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import load_manifest + + dummy_codebook_indexes = torch.ones(1510, 16) * (-100) + dummy_audio_logits = torch.ones(527) * 0.5 + + cuts = load_manifest("debug.jsonl.gz") + cut_ids = [c.task_id for c in cuts] + + augmented_cuts = cuts.map(partial(_add_dummy_embeddings_and_taskIDs, None)) + cuts = load_manifest("debug.jsonl.gz") + + gt_mvq_tokens, gt_mvq_token_lens = collate_custom_field(augmented_cuts, "codebook_indexes", pad_value=-100) + mvq_tokens, mvq_token_lens = _collate_custom_field( + cuts, + "codebook_indexes", + dummy=dummy_codebook_indexes, + temporal_array=True, + pad_value=-100 + ) + print(gt_mvq_tokens) + + gt_beats_embed = collate_custom_field(augmented_cuts, "beats_embedding") + beats_embed = _collate_custom_field(cuts, "beats_embedding", dummy=dummy_audio_logits, temporal_array=False) + + print(beats_embed) + +def _test2(): + from lhotse.dataset.input_strategies import OnTheFlyFeatures + from lhotse import Fbank, FbankConfig + input_strategy = OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + + dataset = MultiTaskKDDataset( + return_cuts=True, + input_strategy=input_strategy, + num_cb_speech=16, + num_cb_audio=8, + ) + bad_batch = torch.load("zipformer_audio_encoder/exp-316M-zipformer-lh-large-giga-xl-emo-1-as-full-music4all-vgg-bbc-freesound-w2v2-mask-p-0.65-len-10-channel-mask-p-0.25-len-20-multi-mvq-hubert-large-cb16-1.0-dasheng-as-cb8-0.3-shar-md400/batch-bdd640fb-0667-1ad1-1c80-317fa3b1799d.pt") + sup = bad_batch["supervisions"] + cuts = CutSet.from_cuts(sup["cut"]) + import pdb; pdb.set_trace() + batch = dataset.__getitem__(cuts) + print(batch) + + +if __name__=="__main__": + # _test_filter_cuts() + _test2() + + + + + diff --git a/egs/emilia/CLAP/spear/decode.py b/egs/emilia/CLAP/spear/decode.py new file mode 100644 index 0000000000..ba645ffeed --- /dev/null +++ b/egs/emilia/CLAP/spear/decode.py @@ -0,0 +1,1090 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from functools import partial + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from finetune_mtl import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +from utils import _add_task_id + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = MultiTaskDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_clean_cuts = test_clean_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + test_other_cuts = librispeech.test_other_cuts() + test_other_cuts = test_other_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/decode_byte.py b/egs/emilia/CLAP/spear/decode_byte.py new file mode 100644 index 0000000000..6157e86205 --- /dev/null +++ b/egs/emilia/CLAP/spear/decode_byte.py @@ -0,0 +1,1157 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from functools import partial + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from finetune_mtl import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) +from icefall import smart_byte_decode +from utils import _add_task_id + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + parser.add_argument( + "--test-libri", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--test-wenet", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--test-aishell", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append("".join(smart_byte_decode(hyp).split())) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = MultiTaskDataModule(args) + + def remove_short(c): + if c.duration < 0.3: + return False + return True + + test_sets = [] + test_dls = [] + + if params.test_libri: + test_clean_cuts = librispeech.test_clean_cuts() + test_clean_cuts = test_clean_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + test_other_cuts = librispeech.test_other_cuts() + test_other_cuts = test_other_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + + test_dls.append(librispeech.test_dataloaders(test_clean_cuts)) + test_dls.append(librispeech.test_dataloaders(test_other_cuts)) + + test_sets += ["ls-test-clean", "ls-test-other"] + + if params.test_wenet: + wenet_test_net_cuts = librispeech.wenetspeech_test_net_cuts() + wenet_test_net_cuts = wenet_test_net_cuts.filter(remove_short) + wenet_test_net_cuts = wenet_test_net_cuts.map(partial(_add_task_id, 1)) + wenet_test_meeting_cuts = librispeech.wenetspeech_test_meeting_cuts() + wenet_test_meeting_cuts = wenet_test_meeting_cuts.map(partial(_add_task_id, 1)) + + test_dls.append(librispeech.test_dataloaders(wenet_test_net_cuts)) + test_dls.append(librispeech.test_dataloaders(wenet_test_meeting_cuts)) + test_sets += [ "wenet-test-net", "wenet-test-meeting"] + + if params.test_aishell: + aishell_dev_cuts = librispeech.aishell_dev_cuts() + aishell_dev_cuts = aishell_dev_cuts.map(partial(_add_task_id, 1)) + aishell_test_cuts = librispeech.aishell_test_cuts() + aishell_test_cuts = aishell_test_cuts.map(partial(_add_task_id, 1)) + + test_dls.append(librispeech.test_dataloaders(aishell_dev_cuts)) + test_dls.append(librispeech.test_dataloaders(aishell_test_cuts)) + + test_sets += [ "aishell-dev", "aishell-test"] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/decode_gigaspeech.py b/egs/emilia/CLAP/spear/decode_gigaspeech.py new file mode 100755 index 0000000000..46049e5332 --- /dev/null +++ b/egs/emilia/CLAP/spear/decode_gigaspeech.py @@ -0,0 +1,1118 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from functools import partial +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from finetune_mtl import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) +from utils import _add_task_id + +LOG_EPS = math.log(1e-10) + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / f"{params.decoding_method}-giga" + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = MultiTaskDataModule(args) + + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + gigaspeech_dev_cuts = gigaspeech_dev_cuts.map(partial(_add_task_id, 1)) + gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts() + gigaspeech_test_cuts = gigaspeech_test_cuts.map(partial(_add_task_id, 1)) + + dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts) + test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/decoder.py b/egs/emilia/CLAP/spear/decoder.py new file mode 100644 index 0000000000..e77e541187 --- /dev/null +++ b/egs/emilia/CLAP/spear/decoder.py @@ -0,0 +1,130 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scaling import Balancer + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + ) + # the balancers are to avoid any drift in the magnitude of the + # embeddings, which would interact badly with parameter averaging. + self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim // 4, # group size == 4 + bias=False, + ) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + + embedding_out = self.balancer(embedding_out) + + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + embedding_out = self.balancer2(embedding_out) + + return embedding_out diff --git a/egs/emilia/CLAP/spear/encoder_interface.py b/egs/emilia/CLAP/spear/encoder_interface.py new file mode 120000 index 0000000000..653c5b09af --- /dev/null +++ b/egs/emilia/CLAP/spear/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/export.py b/egs/emilia/CLAP/spear/export.py new file mode 100644 index 0000000000..55b6607095 --- /dev/null +++ b/egs/emilia/CLAP/spear/export.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn + +from train_multi_KD3_shar import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + logging.info("Decoding started") + + device = torch.device("cpu") + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + load_info = model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + logging.info(load_info) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + if params.iter > 0: + out_path = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, out_path) + else: + out_path = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, out_path) + logging.info(f"Model saved to: {out_path}") + + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/spear/export_asr.py b/egs/emilia/CLAP/spear/export_asr.py new file mode 100644 index 0000000000..fc08b798e3 --- /dev/null +++ b/egs/emilia/CLAP/spear/export_asr.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn + +from finetune_mtl import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model" + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + logging.info("Decoding started") + + device = torch.device("cpu") + + logging.info(f"Device: {device}") + logging.info(params) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + load_info = model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + logging.info(load_info) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + if params.iter > 0: + out_path = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, out_path) + else: + out_path = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, out_path) + logging.info(f"Model saved to: {out_path}") + + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/emilia/CLAP/spear/extract_mvq.py b/egs/emilia/CLAP/spear/extract_mvq.py new file mode 100644 index 0000000000..cb2c64e746 --- /dev/null +++ b/egs/emilia/CLAP/spear/extract_mvq.py @@ -0,0 +1,351 @@ + +import argparse +import os +import io +import logging +from pathlib import Path + +from icefall.utils import AttributeDict, setup_logger, str2bool + +from train_multi_KD3_shar import add_model_arguments, get_encoder_embed, get_encoder_model +from collect_zipformer_embeddings import FbankDataset, ZipformerModel + +import torch +import torch.multiprocessing as mp +from torch.utils.data import DataLoader + +import lhotse +from lhotse import load_manifest, CutSet +from lhotse.cut import MonoCut +from lhotse import Fbank, FbankConfig +from lhotse.dataset import DynamicBucketingSampler +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fastcopy +import multi_quantization as quantization +import numpy as np + +from typing import Union, Optional + +lhotse.set_caching_enabled(True) + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + # quantizer related + parser.add_argument( + "--embedding-dim", + type=int, + default=512, + ) + + parser.add_argument( + "--num-cb", + type=int, + default=4, + ) + + parser.add_argument( + "--quantizer-path", + type=str, + required=True, + ) + + parser.add_argument( + "--s3-prefix", + type=str, + required=True, + default="brainllm:s3://yangxiaoyu/LibriSpeech" + ) + + # others + parser.add_argument( + "--num-jobs", + type=int, + default=1, + ) + + parser.add_argument( + "--input-manifest", + type=str, + required=True, + ) + + parser.add_argument( + "--manifest-name", + type=str, + required=True, + help="name of the manifest, e.g embeddings-dev-clean, embeddings-train-clean-100" + ) + + parser.add_argument( + "--embedding-dir", + type=str, + default="data/vq_whisper" + ) + + parser.add_argument( + "--embedding-layer", + type=int, + default=-1, + help="Which layer's representation should be extracted", + ) + + parser.add_argument( + "--max-duration", + type=int, + default=500, + ) + + parser.add_argument( + "--target-manifest-file", + type=str, + required=True, + help="Where to store the manifest augmented with zipformer features" + ) + + parser.add_argument( + "--normalize", + type=str2bool, + default=False, + help="If True, compute the channel-wise mean and std on the training se for nomalization." + ) + + # zipformer related args + parser.add_argument( + "--model-ckpt", + type=str, + required=True, + ) + + parser.add_argument( + "--zipformer-version", + type=str, + default="300m", + ) + + parser.add_argument( + "--frame-shift", + type=float, + default=0.02, + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=128, + ) + add_model_arguments(parser) + + return parser + +def normalize_data(data, mean, std): + return (data - mean) / (std + 1e-5) + +@torch.no_grad() +def extract_embeddings( + rank: int, + manifest: str, + params: AttributeDict, +): + setup_logger(f"data/vq_zipformer_client/log/log-zipformer-cb-indexes") + if params.num_jobs > 1: + manifest = manifest[rank] + output_manifest = params.embedding_dir / f"zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}-{rank}.jsonl.gz" + else: + output_manifest = params.embedding_dir / f"zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}.jsonl.gz" + + device = torch.device("cuda", rank) + + # currently only use the encoder of zipformer + logging.info(params) + model = ZipformerModel( + encoder_embed=get_encoder_embed(params), + encoder=get_encoder_model(params), + ) + state_dict = torch.load(params.model_ckpt)["model"] + load_info = model.load_state_dict(state_dict, strict=False) + logging.info(load_info) + + model.to(device) + model.eval() + logging.info(f"Number of zipformer model params: {sum(p.numel() for p in model.parameters())}") + logging.info(f"Successfully loaded zipformer model.") + + quantizer = quantization.Quantizer( + dim=params.embedding_dim, + num_codebooks=params.num_cb, + codebook_size=256, + ) + state_dict = torch.load(params.quantizer_path) + if "quantizer" not in state_dict: + # with out normalization stats + assert not params.normalize, "No normalization stats is found!" + state_dict = {"quantizer": state_dict} + + if params.normalize: + mu = state_dict["mean"].to(device) + std = state_dict["std"].to(device) + quantizer.load_state_dict(state_dict["quantizer"]) + quantizer.to(device) + + dataset = FbankDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=True + ) + + sampler = DynamicBucketingSampler( + manifest, + max_duration=params.max_duration, + shuffle=False, + num_buckets=20, + buffer_size=20 * 2000, + shuffle_buffer_size=20 * 5000, + drop_last=False, + ) + + dl = DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + new_cuts = [] + num_cuts = 0 + + logging.info(f"Writing zipformer indexes") + for i, batch in enumerate(dl): + cuts = batch["cuts"] + + with torch.cuda.amp.autocast(enabled=True): + embeddings, embedding_lens = model.get_embeddings( + batch, + layer_idx=params.embedding_layer # which layer's embedding to be stored + ) + embeddings = embeddings.float() + if params.normalize: + embeddings = normalize_data(embeddings, mu, std) + + # codebook_indexes = quantizer.encode(embeddings) # [N, T, C] + N,T,C = embeddings.shape + embeddings = embeddings.reshape(-1, C) + B = 2000 + splits = embeddings.split(B) + codebook_indexes = [] + for chunk in splits: + chunk_indexes = quantizer.encode(chunk) + codebook_indexes.append(chunk_indexes) + codebook_indexes = torch.cat(codebook_indexes).reshape(N,T,params.num_cb) + codebook_indexes = codebook_indexes.to("cpu").numpy() + assert np.min(codebook_indexes) >= 0 + assert np.max(codebook_indexes) < 256 + + for idx, cut in enumerate(cuts): + cb_index = codebook_indexes[idx][: embedding_lens[idx]] + + if "/" in cut.id: + # we are dealing with libriheavy cuts + filename = cut.id + else: + filename = "/".join(cut.id.split("-")[:2]) + "/" + cut.id + output_path = f"{params.s3_prefix}/{filename}.npy" + if os.path.exists(output_path): + logging.info(f"This codebook file has already been generated. Please check if you are doing correctly!") + + base_dir, filename = output_path.rsplit("/", 1) + os.makedirs(base_dir, exist_ok=True) + np.save(output_path, cb_index) + + info = { + "path": output_path, + "shape": list(cb_index.shape), + "frame-shift": params.frame_shift, + } + + new_cut = fastcopy( + cut, + custom={"codebook_indexes": info} + ) + new_cuts.append(new_cut) + num_cuts += 1 + if num_cuts and num_cuts % 100 == 0: + logging.info(f"Cuts processed until now: {num_cuts}") + + logging.info(f"Finished extracting zipformer codebook indexes, processed a total of {num_cuts} cuts.") + + CutSet.from_cuts(new_cuts).to_jsonl(output_manifest) + logging.info(f"Saved manifest to {output_manifest}") + +def join_manifests( + input_cuts: CutSet, + embedding_manifest: str, + output_dir: str, +): + # Combine the teacher embedding manifest with the original manifest for ASR + embedding_cuts = load_manifest(embedding_manifest) + + assert len(embedding_cuts) == len(input_cuts) + assert set(input_cuts.ids) == set(embedding_cuts.ids) + + embedding_cuts = embedding_cuts.sort_like(input_cuts) + for cut_idx, (ori_cut, embed_cut) in enumerate(zip(input_cuts, embedding_cuts)): + assert ori_cut.id == embed_cut.id + ori_cut.codebook_indexes = embed_cut.codebook_indexes + + input_cuts.to_jsonl(output_dir) + logging.info(f"Saved the joined manifest to {output_dir}") + +def remove_short_and_long_utt(c): + if c.duration < 1.0 or c.duration > 29.9: + return False + return True + +def remove_sp(c): + if "sp0.9" in c.id or "sp1.1" in c.id: + return False + return True + +if __name__=="__main__": + parser = get_parser() + args = parser.parse_args() + params = AttributeDict() + params.update(vars(args)) + params.embedding_dir = Path(params.embedding_dir) + + nj = params.num_jobs + print(f"Start loading manifest") + cuts = load_manifest(params.input_manifest) + cuts = cuts.filter(remove_short_and_long_utt) # remove audio longer than 30s + cuts = cuts.filter(remove_sp) # remove speed perturb + print(f"Finished loading manifest") + print(cuts) + + embedding_manifest = params.embedding_dir / f"zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}.jsonl.gz" + + if not embedding_manifest.exists(): + if nj == 1: + extract_embeddings( + rank=0, + manifest=cuts, + params=params, + ) + else: + splitted_cuts = cuts.split(num_splits=nj) + logging.info(f"Finished splitting manifest") + mp.spawn(extract_embeddings, args=(splitted_cuts, params), nprocs=nj, join=True) + manifests = f"{str(params.embedding_dir)}/zipformer-{params.zipformer_version}-layer-{params.embedding_layer}-{params.manifest_name}-*.jsonl.gz" + os.system(f"lhotse combine {manifests} {embedding_manifest}") + else: + logging.info(f"Skip embedding extraction: the manifest is already generated.") + + output_manifest = params.target_manifest_file + if not os.path.exists(output_manifest): + join_manifests( + input_cuts=cuts, + embedding_manifest=embedding_manifest, + output_dir=output_manifest, + ) + + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/finetune.py b/egs/emilia/CLAP/spear/finetune.py new file mode 100644 index 0000000000..107f268b18 --- /dev/null +++ b/egs/emilia/CLAP/spear/finetune.py @@ -0,0 +1,1553 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +import random +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_asr import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) +from utils import compare_model + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + 100000 + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used" + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true" + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info(f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}.") + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + freeze_encoder=freeze_encoder, + ) + simple_loss, pruned_loss, ctc_loss = losses[:3] + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + # resuming training + assert params.start_epoch > 1, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + logging.info(f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}") + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict, world_size=world_size, rank=rank + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + + valid_sets = ["librispeech"] + valid_dls = [ + librispeech.valid_dataloaders(valid_cuts, world_size=world_size, rank=rank), + ] + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/finetune_at.py b/egs/emilia/CLAP/spear/finetune_at.py new file mode 100644 index 0000000000..558dd53cba --- /dev/null +++ b/egs/emilia/CLAP/spear/finetune_at.py @@ -0,0 +1,1955 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from functools import partial +import random +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from at_datamodule import MultiTaskDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_asr import MultiTaskModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer2, SimpleDownsample + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) +from utils import ( + compare_model, + upper_only_alpha, + normalize_chinese_text, + normalize_english_text, + MetricsTracker, + _add_task_id, + map_zh, + setup_distributed, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + batch_count = ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + if params.large_batch_count: + batch_count += 100000 + return batch_count + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used" + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true; -1 means not freezing" + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--post-encoder-downsampling-factor", + type=int, + default=1, + help="The ds factor after the zipformer encoder", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-asr", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + ) + + parser.add_argument( + "--num-events", + type=int, + default=527, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="The initial value of warmup, between 0 and 1" + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0, + help="The number of batches for lr warmup" + ) + + parser.add_argument( + "--large-batch-count", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if params.output_downsampling_factor == 2: + assert params.post_encoder_downsampling_factor == 1, "CANNOT perform double output downsample!" + + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + +def get_encoder_downsample_module(params: AttributeDict) -> nn.Module: + if params.post_encoder_downsampling_factor > 1: + downsample_module = SimpleDownsample( + max(_to_int_tuple(params.encoder_dim)), + downsample=params.post_encoder_downsampling_factor, + dropout=0.0, + ) + else: + downsample_module = None + return downsample_module + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + # assert params.use_transducer or params.use_ctc, ( + # f"At least one of them should be True, " + # f"but got params.use_transducer={params.use_transducer}, " + # f"params.use_ctc={params.use_ctc}" + # ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + post_encoder_downsample = get_encoder_downsample_module(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + assert params.causal == False + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + model = MultiTaskModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_downsample=post_encoder_downsample, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + num_events=params.num_events, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + feature_lens = supervisions["num_frames"].to(device) + + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.02 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + at_targets = batch["at_targets"] if params.do_audio_tagging else None + if at_targets is not None: + at_targets = at_targets.to(device) + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info(f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}.") + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + at_targets=at_targets, + freeze_encoder=freeze_encoder, + skip_asr=not params.do_asr, + ) + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, audio_tagging_loss = losses + + loss = 0.0 + + # ASR related loss + asr_mask = task_ids == 1 + if params.use_transducer and params.do_asr: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + simple_loss = (simple_loss * asr_mask).sum() + pruned_loss = (pruned_loss * asr_mask).sum() + + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_attention_decoder and params.do_asr: + attention_decoder_loss = (attention_decoder_loss * asr_mask).sum() + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + if params.use_ctc and params.do_asr: + ctc_loss = (ctc_loss * asr_mask).sum() + loss += params.ctc_loss_scale * ctc_loss + + if params.do_audio_tagging: + at_mask = task_ids == 2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * at_mask).sum() + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer and params.do_asr: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc and params.do_asr: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder and params.do_asr: + info["attention_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + if params.use_shar: + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = ["/".join(str(c.shard_origin).split("/")[1:3]) for c in cuts] + unique_origin = set(shard_origin) + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if params.use_shar: + cur_batch_idx = params.batch_idx_train + else: + cur_batch_idx = batch_idx + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {cur_batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + rank = setup_distributed() + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + if params.start_epoch > 1: + # resuming training + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + # training from scratch + checkpoints =None + + # Setting the encoder lr scale + logging.info(f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}") + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + warmup_batches=params.warmup_batches, + warmup_start=params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + train_cuts = {} + train_cuts_duration = [] + + assert params.do_asr or params.do_audio_tagging, "At least perform on task!" + + assert not params.do_asr, "This script is only for AT finetuning" + if params.do_asr: + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts = libriheavy_cuts.map(normalize_english_text) + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 500 * 0.9, + "medium": 3687, + "large": 37218, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + mls_cuts = mls_cuts.map(normalize_english_text) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts = wenetspeech_cuts.map(map_zh) + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 14621270, + } + wenetspeech_cuts_duration = { + "S": 100, # 100 hrs + "M": 1000, # 1000 hrs + "L": 10000, # 10000 hrs + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset] * params.repeat_wenetspeech) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset] * params.repeat_wenetspeech) + + if params.use_aishell: + aishell_cuts = librispeech.aishell_train_cuts() + aishell_cuts = aishell_cuts.map(map_zh) + aishell_cuts = aishell_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + # aishell stats: 170 hrs, 120098 cuts + asr_training_cuts.append(aishell_cuts) + asr_training_cuts_lens.append(120098) + asr_training_cuts_duration.append(150) + + if params.use_extra_chinese_dataset: + chinese_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chinese_cuts = chinese_cuts.map(partial(_add_task_id, 1)) + chinese_cuts = chinese_cuts.map(normalize_chinese_text) + chinese_cuts = chinese_cuts.map(map_zh) + asr_training_cuts.append(chinese_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + englishs_cuts = englishs_cuts.map(normalize_english_text) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + # combine the asr data + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=True, + ) + else: + asr_training_cuts = asr_training_cuts[0] + asr_training_cuts_duration = sum(asr_training_cuts_duration) + num_asr_cuts = sum(asr_training_cuts_lens) + + if params.on_the_fly_feats: + asr_training_cuts = asr_training_cuts.drop_features() + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(asr_training_cuts_duration) + + # audio data + assert params.do_audio_tagging + if params.do_audio_tagging: + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + audioset_cuts_lens = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5000, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + if params.on_the_fly_feats: + audioset_cuts = audioset_cuts.drop_features() + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 29 seconds + if c.duration < 1.0 or c.duration > 29.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_asr_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + # construct the training dataloader + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + ) + + # TODO: add more validation sets + valid_sets = [] + valid_dls = [] + + if params.use_librispeech: + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_cuts = valid_cuts.map(partial(_add_task_id, 1)) + valid_sets.append("librispeech") + valid_dls.append( + librispeech.valid_dataloaders(valid_cuts, world_size=world_size, rank=rank), + ) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("gigaspeech") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(map_zh) + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("wenetspeech") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_aishell: + aishell_dev_cuts = librispeech.aishell_dev_cuts() + aishell_dev_cuts = aishell_dev_cuts.map(map_zh) + aishell_dev_cuts = aishell_dev_cuts.map(partial(_add_task_id, 1)) + asr_aishell_valid_dl = librispeech.valid_dataloaders(aishell_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("aishell") + valid_dls.append(asr_aishell_valid_dl) + + if params.use_audioset and params.do_audio_tagging: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + if not params.use_shar: + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/finetune_ctc.py b/egs/emilia/CLAP/spear/finetune_ctc.py new file mode 100644 index 0000000000..b02f8818ee --- /dev/null +++ b/egs/emilia/CLAP/spear/finetune_ctc.py @@ -0,0 +1,1968 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from functools import partial +import random +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_asr import MultiTaskModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer2, SimpleDownsample + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) +from utils import ( + compare_model, + upper_only_alpha, + normalize_chinese_text, + normalize_english_text, + MetricsTracker, + _add_task_id, + map_zh, + setup_distributed, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + batch_count = ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + if params.large_batch_count: + batch_count += 100000 + return batch_count + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used" + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true; -1 means not freezing" + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--post-encoder-downsampling-factor", + type=int, + default=1, + help="The ds factor after the zipformer encoder", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-asr", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + ) + + parser.add_argument( + "--num-events", + type=int, + default=527, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + # normalization + parser.add_argument( + "--normalize-fbank", + type=str2bool, + default=False, + help="If perform normalization to the input fbank features" + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="The initial value of warmup, between 0 and 1" + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0, + help="The number of batches for lr warmup" + ) + + parser.add_argument( + "--large-batch-count", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if params.output_downsampling_factor == 2: + assert params.post_encoder_downsampling_factor == 1, "CANNOT perform double output downsample!" + + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + +def get_encoder_downsample_module(params: AttributeDict) -> nn.Module: + if params.post_encoder_downsampling_factor > 1: + downsample_module = SimpleDownsample( + max(_to_int_tuple(params.encoder_dim)), + downsample=params.post_encoder_downsampling_factor, + dropout=0.0, + ) + else: + downsample_module = None + return downsample_module + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + post_encoder_downsample = get_encoder_downsample_module(params) + if params.output_downsampling_factor == 1: + params.subsampling_factor = 2 + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + assert params.causal == False + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + if params.normalize_fbank: + logging.info("Normalizing the input fbank features") + + model = MultiTaskModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_downsample=post_encoder_downsample, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + num_events=params.num_events, + normalize_fbank=params.normalize_fbank, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + feature_lens = supervisions["num_frames"].to(device) + + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.02 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + at_targets = batch["at_targets"] if params.do_audio_tagging else None + if at_targets is not None: + at_targets = at_targets.to(device) + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info(f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}.") + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + at_targets=at_targets, + freeze_encoder=freeze_encoder, + ) + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, audio_tagging_loss = losses + + loss = 0.0 + + # ASR related loss + asr_mask = task_ids == 1 + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + simple_loss = (simple_loss * asr_mask).sum() + pruned_loss = (pruned_loss * asr_mask).sum() + + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_attention_decoder: + attention_decoder_loss = (attention_decoder_loss * asr_mask).sum() + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + if params.use_ctc: + ctc_loss = (ctc_loss * asr_mask).sum() + loss += params.ctc_loss_scale * ctc_loss + + if params.do_audio_tagging: + at_mask = task_ids == 2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * at_mask).sum() + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attention_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + if params.use_shar: + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = ["/".join(str(c.shard_origin).split("/")[1:3]) for c in cuts] + unique_origin = set(shard_origin) + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if params.use_shar: + cur_batch_idx = params.batch_idx_train + else: + cur_batch_idx = batch_idx + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {cur_batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + if params.start_epoch > 1: + # resuming training + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + # training from scratch + checkpoints =None + + # Setting the encoder lr scale + logging.info(f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}") + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + warmup_batches=params.warmup_batches, + warmup_start=params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + train_cuts = {} + train_cuts_duration = [] + + assert params.do_asr or params.do_audio_tagging, "At least perform on task!" + + if params.do_asr: + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts = libriheavy_cuts.map(normalize_english_text) + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 500 * 0.9, + "medium": 3687, + "large": 37218, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + mls_cuts = mls_cuts.map(normalize_english_text) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts = wenetspeech_cuts.map(map_zh) + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 14621270, + } + wenetspeech_cuts_duration = { + "S": 100, # 100 hrs + "M": 1000, # 1000 hrs + "L": 10000, # 10000 hrs + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset] * params.repeat_wenetspeech) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset] * params.repeat_wenetspeech) + + if params.use_aishell: + aishell_cuts = librispeech.aishell_train_cuts() + aishell_cuts = aishell_cuts.map(map_zh) + aishell_cuts = aishell_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + # aishell stats: 170 hrs, 120098 cuts + asr_training_cuts.append(aishell_cuts) + asr_training_cuts_lens.append(120098) + asr_training_cuts_duration.append(150) + + if params.use_extra_chinese_dataset: + chinese_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chinese_cuts = chinese_cuts.map(partial(_add_task_id, 1)) + chinese_cuts = chinese_cuts.map(normalize_chinese_text) + chinese_cuts = chinese_cuts.map(map_zh) + asr_training_cuts.append(chinese_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + englishs_cuts = englishs_cuts.map(normalize_english_text) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + # combine the asr data + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=True, + ) + else: + asr_training_cuts = asr_training_cuts[0] + asr_training_cuts_duration = sum(asr_training_cuts_duration) + num_asr_cuts = sum(asr_training_cuts_lens) + + if params.on_the_fly_feats: + asr_training_cuts = asr_training_cuts.drop_features() + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(asr_training_cuts_duration) + + # audio data + if params.do_audio_tagging: + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + audioset_cuts_lens = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5000, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + if params.on_the_fly_feats: + audioset_cuts = audioset_cuts.drop_features() + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 29 seconds + if c.duration < 1.0 or c.duration > 29.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_asr_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + # construct the training dataloader + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + ) + + # TODO: add more validation sets + valid_sets = [] + valid_dls = [] + + if params.use_librispeech: + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_cuts = valid_cuts.map(partial(_add_task_id, 1)) + valid_sets.append("librispeech") + valid_dls.append( + librispeech.valid_dataloaders(valid_cuts, world_size=world_size, rank=rank), + ) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("gigaspeech") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(map_zh) + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("wenetspeech") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_aishell: + aishell_dev_cuts = librispeech.aishell_dev_cuts() + aishell_dev_cuts = aishell_dev_cuts.map(map_zh) + aishell_dev_cuts = aishell_dev_cuts.map(partial(_add_task_id, 1)) + asr_aishell_valid_dl = librispeech.valid_dataloaders(aishell_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("aishell") + valid_dls.append(asr_aishell_valid_dl) + + if params.use_audioset and params.do_audio_tagging: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + if not params.use_shar: + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/finetune_mtl.py b/egs/emilia/CLAP/spear/finetune_mtl.py new file mode 100644 index 0000000000..b02f8818ee --- /dev/null +++ b/egs/emilia/CLAP/spear/finetune_mtl.py @@ -0,0 +1,1968 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from functools import partial +import random +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_asr import MultiTaskModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer2, SimpleDownsample + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) +from utils import ( + compare_model, + upper_only_alpha, + normalize_chinese_text, + normalize_english_text, + MetricsTracker, + _add_task_id, + map_zh, + setup_distributed, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + batch_count = ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + if params.large_batch_count: + batch_count += 100000 + return batch_count + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used" + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true; -1 means not freezing" + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--post-encoder-downsampling-factor", + type=int, + default=1, + help="The ds factor after the zipformer encoder", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-asr", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + ) + + parser.add_argument( + "--num-events", + type=int, + default=527, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + # normalization + parser.add_argument( + "--normalize-fbank", + type=str2bool, + default=False, + help="If perform normalization to the input fbank features" + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="The initial value of warmup, between 0 and 1" + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0, + help="The number of batches for lr warmup" + ) + + parser.add_argument( + "--large-batch-count", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if params.output_downsampling_factor == 2: + assert params.post_encoder_downsampling_factor == 1, "CANNOT perform double output downsample!" + + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + +def get_encoder_downsample_module(params: AttributeDict) -> nn.Module: + if params.post_encoder_downsampling_factor > 1: + downsample_module = SimpleDownsample( + max(_to_int_tuple(params.encoder_dim)), + downsample=params.post_encoder_downsampling_factor, + dropout=0.0, + ) + else: + downsample_module = None + return downsample_module + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + post_encoder_downsample = get_encoder_downsample_module(params) + if params.output_downsampling_factor == 1: + params.subsampling_factor = 2 + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + assert params.causal == False + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + if params.normalize_fbank: + logging.info("Normalizing the input fbank features") + + model = MultiTaskModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_downsample=post_encoder_downsample, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + num_events=params.num_events, + normalize_fbank=params.normalize_fbank, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + feature_lens = supervisions["num_frames"].to(device) + + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.02 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + at_targets = batch["at_targets"] if params.do_audio_tagging else None + if at_targets is not None: + at_targets = at_targets.to(device) + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info(f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}.") + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + at_targets=at_targets, + freeze_encoder=freeze_encoder, + ) + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, audio_tagging_loss = losses + + loss = 0.0 + + # ASR related loss + asr_mask = task_ids == 1 + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + simple_loss = (simple_loss * asr_mask).sum() + pruned_loss = (pruned_loss * asr_mask).sum() + + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_attention_decoder: + attention_decoder_loss = (attention_decoder_loss * asr_mask).sum() + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + if params.use_ctc: + ctc_loss = (ctc_loss * asr_mask).sum() + loss += params.ctc_loss_scale * ctc_loss + + if params.do_audio_tagging: + at_mask = task_ids == 2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * at_mask).sum() + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attention_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + if params.use_shar: + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = ["/".join(str(c.shard_origin).split("/")[1:3]) for c in cuts] + unique_origin = set(shard_origin) + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if params.use_shar: + cur_batch_idx = params.batch_idx_train + else: + cur_batch_idx = batch_idx + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {cur_batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + if params.start_epoch > 1: + # resuming training + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + # training from scratch + checkpoints =None + + # Setting the encoder lr scale + logging.info(f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}") + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + warmup_batches=params.warmup_batches, + warmup_start=params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + train_cuts = {} + train_cuts_duration = [] + + assert params.do_asr or params.do_audio_tagging, "At least perform on task!" + + if params.do_asr: + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts = libriheavy_cuts.map(normalize_english_text) + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 500 * 0.9, + "medium": 3687, + "large": 37218, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + mls_cuts = mls_cuts.map(normalize_english_text) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts = wenetspeech_cuts.map(map_zh) + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 14621270, + } + wenetspeech_cuts_duration = { + "S": 100, # 100 hrs + "M": 1000, # 1000 hrs + "L": 10000, # 10000 hrs + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset] * params.repeat_wenetspeech) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset] * params.repeat_wenetspeech) + + if params.use_aishell: + aishell_cuts = librispeech.aishell_train_cuts() + aishell_cuts = aishell_cuts.map(map_zh) + aishell_cuts = aishell_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + # aishell stats: 170 hrs, 120098 cuts + asr_training_cuts.append(aishell_cuts) + asr_training_cuts_lens.append(120098) + asr_training_cuts_duration.append(150) + + if params.use_extra_chinese_dataset: + chinese_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chinese_cuts = chinese_cuts.map(partial(_add_task_id, 1)) + chinese_cuts = chinese_cuts.map(normalize_chinese_text) + chinese_cuts = chinese_cuts.map(map_zh) + asr_training_cuts.append(chinese_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + englishs_cuts = englishs_cuts.map(normalize_english_text) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + # combine the asr data + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=True, + ) + else: + asr_training_cuts = asr_training_cuts[0] + asr_training_cuts_duration = sum(asr_training_cuts_duration) + num_asr_cuts = sum(asr_training_cuts_lens) + + if params.on_the_fly_feats: + asr_training_cuts = asr_training_cuts.drop_features() + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(asr_training_cuts_duration) + + # audio data + if params.do_audio_tagging: + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + audioset_cuts_lens = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5000, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + if params.on_the_fly_feats: + audioset_cuts = audioset_cuts.drop_features() + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 29 seconds + if c.duration < 1.0 or c.duration > 29.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_asr_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + # construct the training dataloader + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + ) + + # TODO: add more validation sets + valid_sets = [] + valid_dls = [] + + if params.use_librispeech: + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_cuts = valid_cuts.map(partial(_add_task_id, 1)) + valid_sets.append("librispeech") + valid_dls.append( + librispeech.valid_dataloaders(valid_cuts, world_size=world_size, rank=rank), + ) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("gigaspeech") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(map_zh) + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("wenetspeech") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_aishell: + aishell_dev_cuts = librispeech.aishell_dev_cuts() + aishell_dev_cuts = aishell_dev_cuts.map(map_zh) + aishell_dev_cuts = aishell_dev_cuts.map(partial(_add_task_id, 1)) + asr_aishell_valid_dl = librispeech.valid_dataloaders(aishell_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("aishell") + valid_dls.append(asr_aishell_valid_dl) + + if params.use_audioset and params.do_audio_tagging: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + if not params.use_shar: + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/inference_audio_tagging.py b/egs/emilia/CLAP/spear/inference_audio_tagging.py new file mode 100644 index 0000000000..f544aaf9fc --- /dev/null +++ b/egs/emilia/CLAP/spear/inference_audio_tagging.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0" + +./zipformer/evaluate.py \ + --epoch 50 \ + --avg 10 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + + +""" + +import argparse +from functools import partial +import logging +from pathlib import Path +from typing import Dict + +import torch +import torch.nn as nn +from at_datamodule import MultiTaskDataModule + +try: + from sklearn.metrics import average_precision_score +except: + raise ImportError(f"Please run\n" "pip3 install -U scikit-learn") +from train_multi_KD3_shar import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool +from utils import _add_dummy_embeddings_and_taskIDs, _add_task_id + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +def inference_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +): + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3, feature.shape + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + label = batch["at_targets"] + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + cuts = supervisions["cut"] + audio_events = [c.supervisions[0].audio_event for c in cuts] + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens, return_logits=True) + # convert to probabilities between 0-1 + audio_logits = audio_logits.sigmoid().detach().cpu() + + return audio_logits, label + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict: + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + all_logits = [] + all_labels = [] + + for batch_idx, batch in enumerate(dl): + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + num_cuts += len(cut_ids) + + audio_logits, labels = inference_one_batch( + params=params, + model=model, + batch=batch, + ) + + all_logits.append(audio_logits) + all_labels.append(labels) + + if batch_idx % 20 == 1: + logging.info(f"Processed {num_cuts} cuts already.") + logging.info("Finish collecting audio logits") + + return all_logits, all_labels + + +@torch.no_grad() +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + + # ASR params + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + + params.update(vars(args)) + + params.res_dir = params.exp_dir / "inference_audio_tagging" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Evaluation started") + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info("About to create model") + + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict( + average_checkpoints(filenames), strict=True + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + ), + strict=False, + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + audioset = MultiTaskDataModule(args) + + audioset_cuts = audioset.audioset_eval_cuts() + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + + audioset_dl = audioset.valid_dataloaders(audioset_cuts) + + test_sets = ["audioset_eval"] + + logits, labels = decode_dataset( + dl=audioset_dl, + params=params, + model=model, + ) + + logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy() + labels = torch.cat(labels, dim=0).long().detach().numpy() + + # compute the metric + mAP = average_precision_score( + y_true=labels, + y_score=logits, + ) + + logging.info(f"mAP for audioset eval is: {mAP}") + + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/joiner.py b/egs/emilia/CLAP/spear/joiner.py new file mode 100644 index 0000000000..dfb0a0057b --- /dev/null +++ b/egs/emilia/CLAP/spear/joiner.py @@ -0,0 +1,67 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from scaling import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim, ( + encoder_out.shape, + decoder_out.shape, + ) + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/emilia/CLAP/spear/kd_datamodule3.py b/egs/emilia/CLAP/spear/kd_datamodule3.py new file mode 100644 index 0000000000..2cf05d3572 --- /dev/null +++ b/egs/emilia/CLAP/spear/kd_datamodule3.py @@ -0,0 +1,885 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from dataset2 import MultiTaskKDDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # KD related + group.add_argument( + "--at-KD", + type=str2bool, + default=True, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="M", + choices=["S", "M", "L"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + choices=["small", "medium", "large"] + ) + + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = MultiTaskKDDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=True, + sv_KD=True + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 1000, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"Max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 1000, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + else: + # there is no need to use a large bucket for audio tagging + # as the duration of cuts are quite similar + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskKDDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + for subset in gigaspeech_list: + logging.info(f"Loading gigaspeech cuts subset: {subset}") + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts += cuts + if self.args.gigaspeech_subset == subset: + break + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get {self.args.libriheavy_subset} subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.libriheavy_subset}.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.wenetspeech_subset}.jsonl.gz" + ) + if self.args.on_the_fly_feats: + cuts_train = cuts_train.drop_features() + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + + cuts = load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz") + if self.args.on_the_fly_feats: + cuts_train = cuts.drop_features() + return cuts + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train-with-3-embeddings.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train-with-3-embeddings.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train-with-3-embeddings.jsonl.gz" + ) + return cuts + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/kd_datamodule3_shar.py b/egs/emilia/CLAP/spear/kd_datamodule3_shar.py new file mode 100644 index 0000000000..6b5bee5f60 --- /dev/null +++ b/egs/emilia/CLAP/spear/kd_datamodule3_shar.py @@ -0,0 +1,1723 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +import os +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + ReverbWithImpulseResponse, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from augmentations import BatchMixing +from dataset2_npy_cache import MultiTaskKDDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--num-mel-bins", + type=int, + default=128, + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--sync-buckets", + type=str2bool, + default=True, + ) + group.add_argument( + "--use-custom-duration-bins", + type=str2bool, + default=False, + ) + group.add_argument( + "--duration-bins", + type=str, + default="None" + ) + group.add_argument( + "--duration-bins-weights", + type=str, + default="None", + ) + group.add_argument( + "--merge-buckets", + type=str2bool, + default=False, + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-rir", + type=str2bool, + default=False, + help="When enabled, perform RIR on the cuts", + ) + + group.add_argument( + "--rir-cuts", + type=str, + default="data/rir/rir_cuts.jsonl.gz", + help="If None, use the default fast random RIR generator" + ) + + group.add_argument( + "--features-mask-size", + type=int, + default=27, + help="The maximum mask bins along the frequency axis in specaug" + ) + + group.add_argument( + "--frames-mask-size", + type=int, + default=100, + help="The maximum mask length along the time axis in specaug" + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--batch-mixing", + type=str2bool, + default=False, + ) + + group.add_argument( + "--batch-mixing-mode", + type=str, + default="batch", + choices=["batch", "musan"], + ) + + group.add_argument( + "--mixing-prob", + type=float, + default=0.5, + help="""The mixing probability, applicable to both musan and in-batch mixing. + In musan, it means the noise mixing prob. In batch mixing, it means the augmentation + prob, consisting of both in-batch mixing and noise mixing. + """ + ) + + group.add_argument( + "--p-noise", + type=float, + default=0.0, + help="The probability of mixing noise from non speech noise. Only applicable to in-batch mixing" + ) + + group.add_argument( + "--min-snr", + type=float, + default=10, + help="The minimum SNR used in noise mixing." + ) + + group.add_argument( + "--min-noise-snr", + type=float, + default=-5, + help="The minimum SNR used in noise mixing from non-speech noise. Only used in BatchMixing" + ) + + group.add_argument( + "--max-snr", + type=float, + default=20, + help="The minimum SNR used in noise mixing." + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + group.add_argument( + "--target-frame-rate", + type=int, + default=50, + help="The frame rate of the target" + ) + + # KD related + group.add_argument( + "--at-KD", + type=str2bool, + default=False, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-fisher", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxpopuli", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="M", + choices=["S", "M", "L"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + choices=["small", "medium", "large"] + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-emotion-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-emo", + type=int, + default=1, + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--use-music4all", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-music4all", + type=int, + default=1, + ) + + group.add_argument( + "--use-vggsound", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-vggsound", + type=int, + default=1, + ) + + group.add_argument( + "--use-bbceffect", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-freesound", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-mtg", + type=str2bool, + default=False, + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + # world_size = 1 + # rank = 0 + + transforms = [] + if self.args.enable_rir: + logging.info("Enable RIR") + if os.path.exists(self.args.rir_cuts): + logging.info("About to get RIR cuts") + rir_cuts = load_manifest_lazy(self.args.rir_cuts) + else: + logging.info("Use the fast random RIR generator as no RIR recordings are provided") + rir_cuts = None + transforms.append( + ReverbWithImpulseResponse(rir_recordings=rir_cuts, p=0.5) + ) + + if self.args.enable_musan: + assert not self.args.batch_mixing, "Do not use musan and in-batch mixing together!" + logging.info(f"Enable MUSAN with minimum SNR={self.args.min_snr}, max SNR={self.args.max_snr}, mixing prob: {self.args.mixing_prob}") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz").drop_features() + transforms.append( + CutMix( + cuts=cuts_musan, p=self.args.mixing_prob, snr=(self.args.min_snr, self.args.max_snr), + preserve_id=True, pad_to_longest=False + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.batch_mixing: + assert not self.args.enable_musan, "Do not use musan and in-batch mixing together!" + if self.args.p_noise > 0.0: + noise_cuts = load_manifest("data/musan/audioset_non_human.jsonl.gz").drop_features() + logging.info(f"Get the noise cuts for batch mixing as well") + else: + noise_cuts = None + t = BatchMixing( + min_snr=self.args.min_snr, + max_snr=self.args.max_snr, + p=self.args.mixing_prob, + min_noise_snr=self.args.min_noise_snr, + p_noise=self.args.p_noise, + noise_cuts=noise_cuts, + ) + transforms.append(t) + logging.info(f"Performing batch mixing: {t}") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=self.args.features_mask_size, + num_feature_masks=2, + frames_mask_size=self.args.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + + assert self.args.on_the_fly_feats + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + logging.info(f"Sync buckets: {self.args.sync_buckets}") + + if self.args.use_custom_duration_bins: + assert self.args.merge_buckets == False, "Cannot use merge buckets when using custom duration bins" + assert self.args.duration_bins != "None", "If use_custom_duration_bins, duration_bins should not be None" + duration_bins = list(map(float, self.args.duration_bins.split(","))) + if self.args.duration_bins_weights != "None": + duration_bins_weights = list(map(float, self.args.duration_bins_weights.split(","))) + assert len(duration_bins_weights) == len(duration_bins) + 1, "The length of duration_bins_weights should be len(duration_bins) + 1" + else: + duration_bins_weights = [1.0] * (len(duration_bins) + 1) + logging.info(f"Using custom duration bins: {duration_bins}, weights: {duration_bins_weights}") + else: + duration_bins = None + duration_bins_weights = None + + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 5000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + sync_buckets=self.args.sync_buckets, + duration_bins=duration_bins, + duration_bins_weights=duration_bins_weights, + merge_buckets=self.args.merge_buckets, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"Max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskKDDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/dev", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def fisher_cuts(self) -> CutSet: + logging.info("About to get Fisher cuts") + # part1: 1016 hrs, 1055801 cuts + # part2: 1025 hrs, 1057637 cuts + parts = ["part1", "part2"] + if self.args.use_shar: + all_cuts = [] + for part in parts: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/fisher/{part}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1016, 1025], + stop_early=False, + ) + else: + part1_cuts = load_manifest_lazy( + self.args.manifest_dir / "fisher_cuts_part1.jsonl.gz" + ) + part2_cuts = load_manifest_lazy( + self.args.manifest_dir / "fisher_cuts_part2.jsonl.gz" + ) + return part1_cuts + part2_cuts + + def voxpopuli_unlabelled_cuts(self) -> CutSet: + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/voxpopuli/en_v2/", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + libriheavy_list = ["small", "medium", "large"] + durations = [466, 4148, 42074] + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(libriheavy_list): + logging.info(f"Getting libriheavy subset {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/libriheavy/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(f"data/vq_whisper_turbo_zh_en_16_v2_numpy/libriheavy_cuts_{subset}.jsonl.gz") + + all_cuts.append(cuts) + if self.args.libriheavy_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ).drop_features() + return all_cuts + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def mls_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + if self.args.use_shar: + num_splits = 8 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/MLS/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ).resample(16000) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"MLS_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","common_voice_20200622","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone","speech_wav"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def weread_dataset_cuts(self): + logging.info("About to get weread dataset") + + + @cached_property + def dataset_duration_stats(self): + stats_file = f"data/stats/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"data/stats/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz") + return cuts.drop_features() + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get audioset eval cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/eval", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy(self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz") + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def music4all_cuts(self) -> CutSet: + logging.info("About to get music4all cuts") + if self.args.use_shar: + logging.info(f"Use share for music4all cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/music4all/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "music4all_cuts_all.jsonl.gz" + ) + + @lru_cache() + def vggsound_train_cuts(self) -> CutSet: + logging.info("About to get vgg sound training cuts") + if self.args.use_shar: + logging.info(f"Use shard for vggsound") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/vggsound/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "vggsound_cuts_train.jsonl.gz" + ) + + @lru_cache() + def vggsound_test_cuts(self) -> CutSet: + logging.info("About to get vgg sound test cuts") + if self.args.use_shar: + logging.info(f"Use shard for vggsound") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/vggsound/test", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "vggsound_cuts_test.jsonl.gz" + ) + + @lru_cache() + def mtg_cuts(self) -> CutSet: + # 1028645 cuts, 2811:31:17 hrs + logging.info("About to get MTG cuts") + if self.args.use_shar: + logging.info(f"Use shard for MTG cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/mtg_wav", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "mtg_wav_cuts_10s.jsonl.gz" + ) + + @lru_cache() + def bbc_soundeffect_train_cuts(self) -> CutSet: + logging.info("About to get BBC sound effect training cuts") + if self.args.use_shar: + logging.info(f"Use shard for BBC cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/bbc_soundeffect/train_10s", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "bbc_soundeffect_cuts_train_10s.jsonl.gz" + ) + + @lru_cache() + def bbc_soundeffect_test_cuts(self) -> CutSet: + logging.info("About to get BBC sound effect test cuts") + if self.args.use_shar: + logging.info(f"Use shard for BBC cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/bbc_soundeffect/test_10s", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "bbc_soundeffect_cuts_test_10s.jsonl.gz" + ) + + @lru_cache() + def freesound_train_cuts(self) -> CutSet: + logging.info("About to get freesound training cuts") + if self.args.use_shar: + logging.info(f"Use shard for freesound cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/freesound/train_10s", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "freesound_cuts_train_10s.jsonl.gz" + ) + + @lru_cache() + def freesound_test_cuts(self) -> CutSet: + logging.info("About to get freesound test cuts") + if self.args.use_shar: + logging.info(f"Use shard for freesound cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/freesound/test_10s", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "freesound_cuts_test_10s.jsonl.gz" + ) + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + @lru_cache() + def meld_train_cust(self) -> CutSet: + logging.info("About to get MELD training cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/MELD/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "meld_cuts_train.jsonl.gz" + ) + + @lru_cache() + def iemocap_cust(self) -> CutSet: + logging.info("About to get IEMOCAP cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/iemocap/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "iemocap_cuts_all.jsonl.gz" + ) + + @lru_cache() + def mead_cuts(self) -> CutSet: + logging.info("About to get MEAD cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/mead/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "mead_cuts_all.jsonl.gz" + ) + + @lru_cache() + def multi_emotion_cuts(self) -> CutSet: + logging.info("About to combine multiple emotion datasets") + iemocap_cuts = self.iemocap_cust() # 7 hrs, 5502 cuts + mead_cuts = self.mead_cuts() # 37 hrs, 31720 cuts + meld_cuts = self.meld_train_cust() # 8.5 hrs, 9045 cuts + return CutSet.mux( + *[iemocap_cuts, mead_cuts, meld_cuts], + stop_early=False, + weights=[5502, 31720, 9045] + ) + + @lru_cache() + def msp_podcast_train_cust(self) -> CutSet: + logging.info("About to get msp podcast training cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/msp_podcast/Train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "msp_podcast_cuts_Train.jsonl.gz" + ) + + @lru_cache() + def msp_podcast_dev_cust(self) -> CutSet: + logging.info("About to get msp podcast development cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/msp_podcast/Development", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "msp_podcast_cuts_Development.jsonl.gz" + ) + +def _test(): + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + + + +def _test_bucketing_sampler(): + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + args.use_shar = False + args.on_the_fly_feats = True + args.max_duration = 600 + args.audioset_subset = "full" + args.manifest_dir = Path("data/vq_dasheng_large_layer_-1_normalize_0_cb_8") + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + + cuts_audioset = mtl_datamodule.audioset_cuts() + cuts_audioset = cuts_audioset.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) + cuts_music4all = mtl_datamodule.music4all_cuts() + cuts_music4all = cuts_music4all.map(partial(_add_dummy_embeddings_and_taskIDs, 3)) + + cuts_train = CutSet.mux( + *[cuts_audioset, cuts_music4all], + weights=[1904746,109269], + stop_early=False + ) + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=cuts_train, + sampling_weight=[100] + ) + count = { + "audioset": 0, + "music": 0, + } + + for batch_idx, batch in enumerate(train_dl): + # import pdb; pdb.set_trace() + task_ids = batch["task_ids"] + num_as_cuts = (task_ids == 2).sum() + num_music_cuts = (task_ids == 3).sum() + count["audioset"] += num_as_cuts + count["music"] += num_music_cuts + print(count) + if batch_idx > 1000: + break + + + +if __name__=="__main__": + _test_bucketing_sampler() \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/kd_datamodule3_shar_multi_teacher.py b/egs/emilia/CLAP/spear/kd_datamodule3_shar_multi_teacher.py new file mode 100644 index 0000000000..fffff84054 --- /dev/null +++ b/egs/emilia/CLAP/spear/kd_datamodule3_shar_multi_teacher.py @@ -0,0 +1,1200 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from dataset_multi_speech_mvq import MultiTaskKDDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--speech-shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--audio-shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--num-mel-bins", + type=int, + default=128, + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--features-mask-size", + type=int, + default=27, + help="The maximum mask bins along the frequency axis in specaug" + ) + + group.add_argument( + "--frames-mask-size", + type=int, + default=100, + help="The maximum mask length along the time axis in specaug" + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # KD related + group.add_argument( + "--at-KD", + type=str2bool, + default=True, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="M", + choices=["S", "M", "L"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + choices=["small", "medium", "large"] + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + # world_size = 1 + # rank = 0 + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=self.args.features_mask_size, + num_feature_masks=2, + frames_mask_size=self.args.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = MultiTaskKDDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=True, + sv_KD=True + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"Max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskKDDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/gigaspeech/dev", + shuffle_shards=False, + ) + else: + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + if self.args.use_shar: + medium_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/libriheavy/medium", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + if self.args.libriheavy_subset == "medium": + return medium_cuts + else: + assert self.args.libriheavy_subset == "large" + large_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/libriheavy/large", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = [medium_cuts, large_cuts] + return CutSet.mux( + *cuts, + weights=[1, 9], + stop_early=False, + ) + + else: + return load_manifest_lazy( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.libriheavy_subset}.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.speech_shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + logging.info("Get wenetspeech dev cuts from shar") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def mls_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + if self.args.use_shar: + num_splits = 8 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.speech_shar_dir)}/MLS/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ).resample(16000) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.speech_shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","common_voice_20200622","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone","speech_wav"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.speech_shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @cached_property + def dataset_duration_stats(self): + stats_file = f"{self.args.shar_dir}/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"{self.args.shar_dir}/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/audioset/eval", + shuffle_shards=False, + ) + return cuts + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/kd_datamodule3_shar_multi_teacher2.py b/egs/emilia/CLAP/spear/kd_datamodule3_shar_multi_teacher2.py new file mode 100644 index 0000000000..cea5e59011 --- /dev/null +++ b/egs/emilia/CLAP/spear/kd_datamodule3_shar_multi_teacher2.py @@ -0,0 +1,1258 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + ReverbWithImpulseResponse, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from dataset_multi_speech_mvq import MultiTaskKDDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--en-shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with english cuts.", + ) + group.add_argument( + "--zh-shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with chinese cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--num-mel-bins", + type=int, + default=128, + ) + group.add_argument( + "--simple-sampler", + type=str2bool, + default=False, + help="When enabled, use the simple cut sampler.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames). To enable buckting sampler" + "in zip sampler, set simple-sampler to False", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-rir", + type=str2bool, + default=False, + help="When enabled, perform RIR on the cuts", + ) + + group.add_argument( + "--rir-cuts", + type=str, + default="data/rir/rir_cuts.jsonl.gz", + help="If None, use the default fast random RIR generator" + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # KD related + group.add_argument( + "--at-KD", + type=str2bool, + default=True, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="M", + choices=["S", "M", "L"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + choices=["small", "medium", "large"] + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-weread", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + # world_size = 1 + # rank = 0 + + transforms = [] + if self.args.enable_rir: + logging.info("Enable MUSAN") + if self.args.rir_cuts is not None: + logging.info("About to get RIR cuts") + rir_cuts = load_manifest_lazy("data/rir/rir_cuts.jsonl.gz") + else: + logging.info("Use the fast random RIR generator as no RIR recordings are provided") + rir_cuts = None + transforms.append( + ReverbWithImpulseResponse(rir_recordings=rir_cuts, p=0.5) + ) + + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/noise/noise_all.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = MultiTaskKDDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=True, + sv_KD=True + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"Max duration for {name}: {md}") + if "audioset" not in name: # speech cuts + if self.args.simple_sampler: + sampler = SimpleCutSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 2000, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + # there is no need to use bucketing sampler for audioset data + # the distribution of cut duration is even + sampler = SimpleCutSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskKDDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + import lhotse + lhotse.set_caching_enabled(True) + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + if self.args.use_shar: + medium_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/libriheavy/medium", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + if self.args.libriheavy_subset == "medium": + return medium_cuts + else: + assert self.args.libriheavy_subset == "large" + large_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/libriheavy/large", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = [medium_cuts, large_cuts] + return CutSet.mux( + *cuts, + weights=[1, 9], + stop_early=False, + ) + + else: + return load_manifest_lazy( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.libriheavy_subset}.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.zh_shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + logging.info("Get wenetspeech dev cuts from shar") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.zh_shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def mls_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + if self.args.use_shar: + num_splits = 8 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.en_shar_dir)}/MLS/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ).resample(16000) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.en_shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + datasets += ["speech_wav"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.zh_shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def weread_dataset_cuts(self): + logging.info("About to get weread dataset") + num_splits = 10 + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for split in range(num_splits): + logging.info(f"Loading weread split {split}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.zh_shar_dir}/weread/split_{split}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(9000) + cuts_len.append(300) + all_cuts = CutSet.mux( + *all_cuts, + weights=[1]*len(all_cuts), # each split is the same duration + stop_early=False + ) + all_duration = num_splits * 6000 + all_len = num_splits * 2999930 + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of weread data. ") + return all_cuts, all_duration, all_len + + + @cached_property + def dataset_duration_stats(self): + stats_file = f"data/stats/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"data/stats/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts.drop_features() + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.en_shar_dir)}/audioset/eval", + shuffle_shards=False, + ) + return cuts + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/kd_datamodule3_shar_speech_audio_multi_teacher.py b/egs/emilia/CLAP/spear/kd_datamodule3_shar_speech_audio_multi_teacher.py new file mode 100644 index 0000000000..0befe1b653 --- /dev/null +++ b/egs/emilia/CLAP/spear/kd_datamodule3_shar_speech_audio_multi_teacher.py @@ -0,0 +1,1774 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from dataset_speech_audio_mvq import MultiTaskKDDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--speech-shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--audio-shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--max-cuts", + type=int, + default=2000, + help="Maximum number of cuts per batch; Useful to adjust this when" + "seeing CUDA OOM in zipsampler", + ) + group.add_argument( + "--batch-duration-factor", + type=int, + default=4, + help="""Used to filter shorter cuts when batching. The ZipSampler can sometime + produce very large batch, this is a double safety measure to prevent the model + from OOM error. This is evaluated as an upperlimit: batch_duration_factor * max_duration + """, + ) + group.add_argument( + "--num-mel-bins", + type=int, + default=128, + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--sync-buckets", + type=str2bool, + default=True, + ) + group.add_argument( + "--use-custom-duration-bins", + type=str2bool, + default=False, + ) + group.add_argument( + "--duration-bins", + type=str, + default="None" + ) + group.add_argument( + "--duration-bins-weights", + type=str, + default="None", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--features-mask-size", + type=int, + default=27, + help="The maximum mask bins along the frequency axis in specaug" + ) + + group.add_argument( + "--frames-mask-size", + type=int, + default=100, + help="The maximum mask length along the time axis in specaug" + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--mixing-prob", + type=float, + default=0.5, + help="The mixing probability, applicable to both musan and in-batch mixing" + ) + + group.add_argument( + "--min-snr", + type=float, + default=10, + help="The minimum SNR used in noise mixing." + ) + + group.add_argument( + "--max-snr", + type=float, + default=20, + help="The minimum SNR used in noise mixing." + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # KD related + group.add_argument( + "--at-KD", + type=str2bool, + default=True, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + group.add_argument( + "--speech-target-frame-rate", + type=int, + default=50, + help="The speech target's frame rate in Hz" + ) + + group.add_argument( + "--audio-target-frame-rate", + type=int, + default=25, + help="The audio target's frame rate in Hz" + ) + + group.add_argument( + "--num-cb-speech", + type=int, + default=16, + help="Number of codebooks for speech MVQ" + ) + + group.add_argument( + "--num-cb-audio", + type=int, + default=8, + help="Number of codebooks for audio MVQ" + ) + + # multi task dataset related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--repeat-gigaspeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-fisher", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxpopuli", + type=str2bool, + default=False, + ) + group.add_argument( + "--voxpopuli-subset", + type=str, + default="en_v2", + ) + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="M", + choices=["S", "M", "L"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + choices=["small", "medium", "large"] + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-emotion-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-emo", + type=int, + default=1, + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--use-music4all", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-music4all", + type=int, + default=1, + ) + + group.add_argument( + "--use-vggsound", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-vggsound", + type=int, + default=1, + ) + + group.add_argument( + "--use-bbceffect", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-freesound", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-mtg", + type=str2bool, + default=False, + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + + transforms = [] + if self.args.enable_musan: + logging.info(f"Enable MUSAN with minimum SNR={self.args.min_snr}, mixing prob: {self.args.mixing_prob}") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz").drop_features() + transforms.append( + CutMix( + cuts=cuts_musan, p=0.5, snr=(self.args.min_snr, self.args.max_snr), preserve_id=True, pad_to_longest=False + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=self.args.features_mask_size, + num_feature_masks=2, + frames_mask_size=self.args.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, + ) + ) + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}, " + f"frames_mask_size: {self.args.frames_mask_size}, " + f"features_mask_size: {self.args.features_mask_size}" + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + assert self.args.on_the_fly_feats + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + speech_target_frame_rate=self.args.speech_target_frame_rate, + num_cb_speech=self.args.num_cb_speech, + audio_target_frame_rate=self.args.audio_target_frame_rate, + num_cb_audio=self.args.num_cb_audio, + batch_duration_threshold=self.args.max_duration * self.args.batch_duration_factor, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + logging.info(f"Sync buckets: {self.args.sync_buckets}") + if self.args.use_custom_duration_bins: + assert self.args.duration_bins != "None", "If use_custom_duration_bins, duration_bins should not be None" + duration_bins = list(map(float, self.args.duration_bins.split(","))) + if self.args.duration_bins_weights != "None": + duration_bins_weights = list(map(float, self.args.duration_bins_weights.split(","))) + assert len(duration_bins_weights) == len(duration_bins) + 1, "The length of duration_bins_weights should be len(duration_bins) + 1" + else: + duration_bins_weights = [1.0] * (len(duration_bins) + 1) + logging.info(f"Using custom duration bins: {duration_bins}, weights: {duration_bins_weights}") + else: + duration_bins = None + duration_bins_weights = None + # duration_bins = [2.0, 5.0, 9.9, 10.1, 15, 22] + # duration_bins_weights = [1,1,1,2.5,1,1,1] + # logging.info(f"Using weighted duration bins: {duration_bins}, weights: {duration_bins_weights}") + # logging.info("Ignoring pre-defined num buckets because duration bins is given.") + import pdb; pdb.set_trace() + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 50000, + shuffle_buffer_size=self.args.num_buckets * 50000, + drop_last=self.args.drop_last, + sync_buckets=self.args.sync_buckets, + duration_bins=duration_bins, + duration_bins_weights=duration_bins_weights, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"Max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + max_cuts=self.args.max_cuts, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + speech_target_frame_rate=self.args.speech_target_frame_rate, + num_cb_speech=self.args.num_cb_speech, + audio_target_frame_rate=self.args.audio_target_frame_rate, + num_cb_audio=self.args.num_cb_audio, + ) + else: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + speech_target_frame_rate=self.args.speech_target_frame_rate, + num_cb_speech=self.args.num_cb_speech, + audio_target_frame_rate=self.args.audio_target_frame_rate, + num_cb_audio=self.args.num_cb_audio, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskKDDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + speech_target_frame_rate=self.args.speech_target_frame_rate, + num_cb_speech=self.args.num_cb_speech, + audio_target_frame_rate=self.args.audio_target_frame_rate, + num_cb_audio=self.args.num_cb_audio, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/gigaspeech/dev", + shuffle_shards=False, + ) + else: + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def fisher_cuts(self) -> CutSet: + logging.info("About to get Fisher cuts") + # part1: 1016 hrs, 1055801 cuts + # part2: 1025 hrs, 1057637 cuts + parts = ["part1", "part2"] + if self.args.use_shar: + all_cuts = [] + for part in parts: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/fisher/{part}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1016, 1025], + stop_early=False, + ) + else: + part1_cuts = load_manifest_lazy( + self.args.manifest_dir / "fisher_cuts_part1.jsonl.gz" + ) + part2_cuts = load_manifest_lazy( + self.args.manifest_dir / "fisher_cuts_part2.jsonl.gz" + ) + return part1_cuts + part2_cuts + + @lru_cache() + def voxpopuli_asr_train_cuts(self) -> CutSet: + # languages = ["en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"] + VOX_POPULI_LANGUAGES = { + "en": 514, "de": 270, "fr": 202, "es": 153, "pl": 100, "it": 80, "ro": 77, "hu": 55, "cs": 55, + "nl": 48, "fi": 22, "hr": 18.5, "sk": 31, "sl": 6.5, "et": 2, "lt": 1.5, + } # total 1636 hrs, 526497 cuts + + all_cuts = [] + duration_weights = [] + for lang, dur in VOX_POPULI_LANGUAGES.items(): + logging.info(f"Loading voxpopuli {lang}") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/voxpopuli/{lang}/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / f"voxpopuli-asr-{lang}_cuts_train.jsonl.gz" + ) + all_cuts.append(cuts) + duration_weights.append(dur) + + all_cuts = CutSet.mux( + *all_cuts, + weights=duration_weights, + stop_early=False, + ) + all_cuts = all_cuts.map(fix_supervisions) + all_cuts = all_cuts.filter(filter_supervisions_start) + return all_cuts + + @lru_cache() + def voxpopuli_asr_dev_cuts(self) -> CutSet: + # languages = ["en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"] + VOX_POPULI_LANGUAGES = { + "en": 514, "de": 270, "fr": 202, "es": 153, "pl": 100, "it": 80, "ro": 77, "hu": 55, "cs": 55, + "nl": 48, "fi": 22, "hr": 18.5, "sk": 31, "sl": 6.5, "et": 2, + } + + all_cuts = [] + duration_weights = [] + for lang, dur in VOX_POPULI_LANGUAGES.items(): + logging.info(f"Loading voxpopuli {lang}") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/voxpopuli/{lang}/dev", + shuffle_shards=False, + ) + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / f"voxpopuli-asr-{lang}_cuts_dev.jsonl.gz" + ) + all_cuts.append(cuts) + duration_weights.append(dur) + + all_cuts = CutSet.mux( + *all_cuts, + weights=[1.0]*len(all_cuts), + stop_early=False, + ) + all_cuts = all_cuts.map(fix_supervisions) + all_cuts = all_cuts.filter(filter_supervisions_start) + return all_cuts + + def voxpopuli_unlabelled_cuts(self) -> CutSet: + if self.args.use_shar: + logging.info(f"Loading the unlabelled voxpopuli data: {self.args.voxpopuli_subset}") + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/voxpopuli/{self.args.voxpopuli_subset}/", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"voxpopuli_cuts_{self.args.voxpopuli_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + libriheavy_list = ["small", "medium", "large"] + durations = [466, 4148, 42074] + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(libriheavy_list): + logging.info(f"Getting libriheavy subset {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/libriheavy/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(f"data/vq_whisper_turbo_zh_en_16_v2_numpy/libriheavy_cuts_{subset}.jsonl.gz") + + all_cuts.append(cuts) + if self.args.libriheavy_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ).drop_features() + return all_cuts + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.speech_shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + logging.info("Get wenetspeech dev cuts from shar") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def mls_train_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + LANGUAGES={ + "german": 1966.5, "dutch": 1554, "french": 1077, "polish": 104, "spanish": 918, "italian": 247, "portuguese": 161 + } + all_cuts = [] + durations = [] + + for lang, dur in LANGUAGES.items(): + if self.args.use_shar: + split_dir = f"{str(self.args.speech_shar_dir)}/mls/{lang}/train" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / f"mls-asr-{lang}_train.jsonl.gz" + ).resample(16000) + all_cuts.append(cuts) + durations.append(dur) + return CutSet.mux( + *all_cuts, + weights=durations, + stop_early=False, + ) + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.speech_shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","common_voice_20200622","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone","speech_wav"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.speech_shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @cached_property + def dataset_duration_stats(self): + stats_file = f"{self.args.shar_dir}/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"{self.args.shar_dir}/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{self.args.audio_shar_dir}/audioset/eval", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def vggsound_train_cuts(self) -> CutSet: + logging.info("About to get vgg sound training cuts") + if self.args.use_shar: + logging.info(f"Use shard for vggsound") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/vggsound/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "vggsound_cuts_train.jsonl.gz" + ) + + @lru_cache() + def vggsound_test_cuts(self) -> CutSet: + logging.info("About to get vgg sound test cuts") + if self.args.use_shar: + logging.info(f"Use shard for vggsound") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/vggsound/test", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "vggsound_cuts_test.jsonl.gz" + ) + + @lru_cache() + def mtg_cuts(self) -> CutSet: + # 1028645 cuts, 2811:31:17 hrs + logging.info("About to get MTG cuts") + if self.args.use_shar: + logging.info(f"Use shard for MTG cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/mtg_wav", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "mtg_wav_cuts_10s.jsonl.gz" + ) + + @lru_cache() + def music4all_cuts(self) -> CutSet: + logging.info("About to get music4all cuts") + if self.args.use_shar: + logging.info(f"Use share for music4all cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/music4all/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "music4all_cuts_all.jsonl.gz" + ) + + @lru_cache() + def bbc_soundeffect_train_cuts(self) -> CutSet: + logging.info("About to get BBC sound effect training cuts") + if self.args.use_shar: + logging.info(f"Use shard for BBC cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/bbc_soundeffect/train_10s", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "bbc_soundeffect_cuts_train_10s.jsonl.gz" + ) + + @lru_cache() + def bbc_soundeffect_test_cuts(self) -> CutSet: + logging.info("About to get BBC sound effect test cuts") + if self.args.use_shar: + logging.info(f"Use shard for BBC cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/bbc_soundeffect/test_10s", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "bbc_soundeffect_cuts_test_10s.jsonl.gz" + ) + + @lru_cache() + def freesound_train_cuts(self) -> CutSet: + logging.info("About to get freesound training cuts") + if self.args.use_shar: + logging.info(f"Use shard for freesound cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/freesound/train_10s", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "freesound_cuts_train_10s.jsonl.gz" + ) + + @lru_cache() + def freesound_test_cuts(self) -> CutSet: + logging.info("About to get freesound test cuts") + if self.args.use_shar: + logging.info(f"Use shard for freesound cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.audio_shar_dir)}/freesound/test_10s", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "freesound_cuts_test_10s.jsonl.gz" + ) + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + @lru_cache() + def meld_train_cust(self) -> CutSet: + logging.info("About to get MELD training cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/MELD/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "meld_cuts_train.jsonl.gz" + ) + + @lru_cache() + def iemocap_cust(self) -> CutSet: + logging.info("About to get IEMOCAP cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/iemocap/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "iemocap_cuts_all.jsonl.gz" + ) + + @lru_cache() + def mead_cuts(self) -> CutSet: + logging.info("About to get MEAD cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/mead/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "mead_cuts_all.jsonl.gz" + ) + + @lru_cache() + def multi_emotion_cuts(self) -> CutSet: + logging.info("About to combine multiple emotion datasets") + iemocap_cuts = self.iemocap_cust() # 7 hrs, 5502 cuts + mead_cuts = self.mead_cuts() # 37 hrs, 31720 cuts + meld_cuts = self.meld_train_cust() # 8.5 hrs, 9045 cuts + return CutSet.mux( + *[iemocap_cuts, mead_cuts, meld_cuts], + stop_early=False, + weights=[5502, 31720, 9045] + ) + + @lru_cache() + def msp_podcast_train_cust(self) -> CutSet: + logging.info("About to get msp podcast training cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/msp_podcast/Train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "msp_podcast_cuts_Train.jsonl.gz" + ) + + @lru_cache() + def msp_podcast_dev_cust(self) -> CutSet: + logging.info("About to get msp podcast development cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.speech_shar_dir)}/msp_podcast/Development", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "msp_podcast_cuts_Development.jsonl.gz" + ) +def fix_supervisions(cut): + supervision = cut.supervisions[0] + cut.supervisions = [supervision] + return cut + +def filter_supervisions_start(c): + if c.supervisions[0].start != 0.0: + return False + return True + + +if __name__=="__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + args.gigaspeech_subset = "xl" + args.libriheavy_subset = "large" + args.audioset_subset = "full" + args.use_shar = True + args.speech_shar_dir = "data-shar/data-shar-hubert-large-layer-21-normalize-cb16-hdf5" + args.audio_shar_dir = "data-shar/data-shar-dasheng-as-cb8" + args.num_buckets = 20 + args.on_the_fly_feats = 1 + args.sync_buckets = False + args.num_workers = 0 + args.at_KD = False + args.max_duration = 400 + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + + import pdb; pdb.set_trace() + libriheavy_cuts = mtl_datamodule.libriheavy_train_cuts() + gigaspeech_cuts = mtl_datamodule.gigaspeech_train_cuts() + asr_cuts = [libriheavy_cuts, gigaspeech_cuts] + asr_cuts = CutSet.mux( + *asr_cuts, + weights=[10093746, 8611516], + stop_early=False, + ) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + + def change_codebook_indexes(c): + c.audio_codebook_indexes = c.codebook_indexes + del c.codebook_indexes + return c + + audio_cuts = mtl_datamodule.audioset_cuts().repeat(4) + audio_cuts = audio_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + audio_cuts = audio_cuts.map(change_codebook_indexes) + + train_cuts = [asr_cuts, audio_cuts] + train_cuts = CutSet.mux( + *train_cuts, + weights=[2,1], + stop_early=False, + ) + + import pdb; pdb.set_trace() + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + ) + num_epochs = 3 + import pdb; pdb.set_trace() + for epoch in range(1, num_epochs+1): + # train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + duration1, duration2 = 0,0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + cuts = batch["supervisions"]["cut"] + cuts_1 = [c for c in cuts if c.task_id == 1] + cuts_2 = [c for c in cuts if c.task_id == 2] + duration1 += sum([c.duration for c in cuts_1]) + duration2 += sum([c.duration for c in cuts_2]) + logging.info(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)}, {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 200: + break + # if batch_idx == 0: + # print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Sample stats: {num1}, {num2}; Duration stats: {duration1}, {duration2}") + # print(f"Number of cuts from task1: {num1}") + # print(f"Number of cuts from task2: {num2}") + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/kd_datamodule3_shar_token_mixing.py b/egs/emilia/CLAP/spear/kd_datamodule3_shar_token_mixing.py new file mode 100644 index 0000000000..289aad0eca --- /dev/null +++ b/egs/emilia/CLAP/spear/kd_datamodule3_shar_token_mixing.py @@ -0,0 +1,1637 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +import os +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + ReverbWithImpulseResponse, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from augmentations import BatchMixing +from dataset2_batch_mixing import MultiTaskKDDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--num-mel-bins", + type=int, + default=128, + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-rir", + type=str2bool, + default=False, + help="When enabled, perform RIR on the cuts", + ) + + group.add_argument( + "--rir-cuts", + type=str, + default="data/rir/rir_cuts.jsonl.gz", + help="If None, use the default fast random RIR generator" + ) + + group.add_argument( + "--features-mask-size", + type=int, + default=27, + help="The maximum mask bins along the frequency axis in specaug" + ) + + group.add_argument( + "--frames-mask-size", + type=int, + default=100, + help="The maximum mask length along the time axis in specaug" + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=False, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--token-mixing", + type=str2bool, + default=True, + help="Perform token mixing, it does batch mixing internally. And it also interpolate the tokens", + ) + + group.add_argument( + "--batch-mixing", + type=str2bool, + default=False, + ) + + group.add_argument( + "--mixing-prob", + type=float, + default=0.2, + help="The mixing probability, applicable to both musan and in-batch mixing" + ) + + group.add_argument( + "--min-snr", + type=float, + default=10, + help="The minimum SNR used in noise mixing." + ) + + group.add_argument( + "--max-snr", + type=float, + default=10, + help="The minimum SNR used in noise mixing." + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + group.add_argument( + "--target-frame-rate", + type=int, + default=50, + help="The frame rate of the target" + ) + + # KD related + group.add_argument( + "--at-KD", + type=str2bool, + default=False, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-fisher", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxpopuli", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="M", + choices=["S", "M", "L"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + choices=["small", "medium", "large"] + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-emotion-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-emo", + type=int, + default=1, + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--use-music4all", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-music4all", + type=int, + default=1, + ) + + group.add_argument( + "--use-vggsound", + type=str2bool, + default=False, + ) + + group.add_argument( + "--repeat-vggsound", + type=int, + default=1, + ) + + group.add_argument( + "--use-bbceffect", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-freesound", + type=str2bool, + default=False, + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + # world_size = 1 + # rank = 0 + + transforms = [] + if self.args.enable_rir: + logging.info("Enable RIR") + if os.path.exists(self.args.rir_cuts): + logging.info("About to get RIR cuts") + rir_cuts = load_manifest_lazy(self.args.rir_cuts) + else: + logging.info("Use the fast random RIR generator as no RIR recordings are provided") + rir_cuts = None + transforms.append( + ReverbWithImpulseResponse(rir_recordings=rir_cuts, p=0.5) + ) + + if self.args.enable_musan: + assert not self.args.batch_mixing, "Do not use musan and in-batch mixing together!" + assert not self.args.token_mixing + logging.info(f"Enable MUSAN with minimum SNR={self.args.min_snr}") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz").drop_features() + transforms.append( + CutMix( + cuts=cuts_musan, p=0.5, snr=(self.args.min_snr, self.args.max_snr), preserve_id=True, pad_to_longest=False + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.token_mixing: + assert self.args.batch_mixing + + if self.args.batch_mixing: + assert not self.args.enable_musan, "Do not use musan and in-batch mixing together!" + t = BatchMixing( + min_snr=self.args.min_snr, + max_snr=self.args.max_snr, + p=self.args.mixing_prob, + ) + transforms.append(t) + logging.info(f"Performing batch mixing: {t}") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=self.args.features_mask_size, + num_feature_masks=2, + frames_mask_size=self.args.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + + assert self.args.on_the_fly_feats + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + token_mixing=self.args.token_mixing, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"Max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 1500, + shuffle_buffer_size=self.args.num_buckets * 1500, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskKDDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskKDDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + target_frame_rate=self.args.target_frame_rate, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/dev", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def fisher_cuts(self) -> CutSet: + logging.info("About to get Fisher cuts") + # part1: 1016 hrs, 1055801 cuts + # part2: 1025 hrs, 1057637 cuts + parts = ["part1", "part2"] + if self.args.use_shar: + all_cuts = [] + for part in parts: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/fisher/{part}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1016, 1025], + stop_early=False, + ) + else: + part1_cuts = load_manifest_lazy( + self.args.manifest_dir / "fisher_cuts_part1.jsonl.gz" + ) + part2_cuts = load_manifest_lazy( + self.args.manifest_dir / "fisher_cuts_part2.jsonl.gz" + ) + return part1_cuts + part2_cuts + + def voxpopuli_unlabelled_cuts(self) -> CutSet: + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/voxpopuli/en_v2/", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + libriheavy_list = ["small", "medium", "large"] + durations = [466, 4148, 42074] + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(libriheavy_list): + logging.info(f"Getting libriheavy subset {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/libriheavy/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(f"data/vq_whisper_turbo_zh_en_16_v2_numpy/libriheavy_cuts_{subset}.jsonl.gz") + + all_cuts.append(cuts) + if self.args.libriheavy_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ).drop_features() + return all_cuts + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def mls_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + if self.args.use_shar: + num_splits = 8 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/MLS/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ).resample(16000) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"MLS_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","common_voice_20200622","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone","speech_wav"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def weread_dataset_cuts(self): + logging.info("About to get weread dataset") + + + @cached_property + def dataset_duration_stats(self): + stats_file = f"data/stats/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"data/stats/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts.drop_features() + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get audioset eval cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/eval", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def music4all_cuts(self) -> CutSet: + logging.info("About to get music4all cuts") + if self.args.use_shar: + logging.info(f"Use share for music4all cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/music4all/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "music4all_cuts_all.jsonl.gz" + ) + + @lru_cache() + def vggsound_train_cuts(self) -> CutSet: + logging.info("About to get vgg sound training cuts") + if self.args.use_shar: + logging.info(f"Use shard for vggsound") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/vggsound/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "vggsound_cuts_train.jsonl.gz" + ) + + @lru_cache() + def vggsound_test_cuts(self) -> CutSet: + logging.info("About to get vgg sound test cuts") + if self.args.use_shar: + logging.info(f"Use shard for vggsound") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/vggsound/test", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "vggsound_cuts_test.jsonl.gz" + ) + + @lru_cache() + def bbc_soundeffect_train_cuts(self) -> CutSet: + logging.info("About to get BBC sound effect training cuts") + if self.args.use_shar: + logging.info(f"Use shard for BBC cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/bbc_soundeffect/train_10s", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "bbc_soundeffect_cuts_train_10s.jsonl.gz" + ) + + @lru_cache() + def bbc_soundeffect_test_cuts(self) -> CutSet: + logging.info("About to get BBC sound effect test cuts") + if self.args.use_shar: + logging.info(f"Use shard for BBC cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/bbc_soundeffect/test_10s", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "bbc_soundeffect_cuts_test_10s.jsonl.gz" + ) + + @lru_cache() + def freesound_train_cuts(self) -> CutSet: + logging.info("About to get freesound training cuts") + if self.args.use_shar: + logging.info(f"Use shard for freesound cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/freesound/train_10s", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "freesound_cuts_train_10s.jsonl.gz" + ) + + @lru_cache() + def freesound_test_cuts(self) -> CutSet: + logging.info("About to get freesound test cuts") + if self.args.use_shar: + logging.info(f"Use shard for freesound cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/freesound/test_10s", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "freesound_cuts_test_10s.jsonl.gz" + ) + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + @lru_cache() + def meld_train_cust(self) -> CutSet: + logging.info("About to get MELD training cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/MELD/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "meld_cuts_train.jsonl.gz" + ) + + @lru_cache() + def iemocap_cust(self) -> CutSet: + logging.info("About to get IEMOCAP cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/iemocap/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "iemocap_cuts_all.jsonl.gz" + ) + + @lru_cache() + def mead_cuts(self) -> CutSet: + logging.info("About to get MEAD cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/mead/all", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "mead_cuts_all.jsonl.gz" + ) + + @lru_cache() + def multi_emotion_cuts(self) -> CutSet: + logging.info("About to combine multiple emotion datasets") + iemocap_cuts = self.iemocap_cust() # 7 hrs, 5502 cuts + mead_cuts = self.mead_cuts() # 37 hrs, 31720 cuts + meld_cuts = self.meld_train_cust() # 8.5 hrs, 9045 cuts + return CutSet.mux( + *[iemocap_cuts, mead_cuts, meld_cuts], + stop_early=False, + weights=[5502, 31720, 9045] + ) + + @lru_cache() + def msp_podcast_train_cust(self) -> CutSet: + logging.info("About to get msp podcast training cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/msp_podcast/Train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "msp_podcast_cuts_Train.jsonl.gz" + ) + + @lru_cache() + def msp_podcast_dev_cust(self) -> CutSet: + logging.info("About to get msp podcast development cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/msp_podcast/Development", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "msp_podcast_cuts_Development.jsonl.gz" + ) + +def _test(): + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + + + +def _test_bucketing_sampler(): + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + args.use_shar = False + args.on_the_fly_feats = True + args.max_duration = 600 + args.audioset_subset = "full" + args.manifest_dir = Path("data/vq_dasheng_large_layer_-1_normalize_0_cb_8") + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + + cuts_audioset = mtl_datamodule.audioset_cuts() + cuts_audioset = cuts_audioset.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) + cuts_music4all = mtl_datamodule.music4all_cuts() + cuts_music4all = cuts_music4all.map(partial(_add_dummy_embeddings_and_taskIDs, 3)) + + cuts_train = CutSet.mux( + *[cuts_audioset, cuts_music4all], + weights=[1904746,109269], + stop_early=False + ) + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=cuts_train, + sampling_weight=[100] + ) + count = { + "audioset": 0, + "music": 0, + } + + for batch_idx, batch in enumerate(train_dl): + # import pdb; pdb.set_trace() + task_ids = batch["task_ids"] + num_as_cuts = (task_ids == 2).sum() + num_music_cuts = (task_ids == 3).sum() + count["audioset"] += num_as_cuts + count["music"] += num_music_cuts + print(count) + if batch_idx > 1000: + break + + + +if __name__=="__main__": + _test_bucketing_sampler() \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/label_smoothing.py b/egs/emilia/CLAP/spear/label_smoothing.py new file mode 100644 index 0000000000..adce2bd04c --- /dev/null +++ b/egs/emilia/CLAP/spear/label_smoothing.py @@ -0,0 +1,109 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. + """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) + + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + ) + + # Set the value of ignored indexes to 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/mask_mae.py b/egs/emilia/CLAP/spear/mask_mae.py new file mode 100644 index 0000000000..894393bf0a --- /dev/null +++ b/egs/emilia/CLAP/spear/mask_mae.py @@ -0,0 +1,446 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import random +from typing import List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from encoder_interface import EncoderInterface +from lhotse.dataset import SpecAugment + +from icefall.utils import AttributeDict, make_pad_mask + + +class AudioPretrainingModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: nn.Module, + fbank_dim: int = 80, + encoder_dim: int = 384, + encoder_input_dim: int = 192, + decoder_dim: int = 384, + decoder_input_dim: int = 192, + noise_scale: float = 0.1, + mask_prob: float = 0.65, + mask_length: int = 10, + mask_selection: str = "static", + mask_other: float = 0.0, + ): + """An audio pretraining model + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + encoder_dim: + Dimension of the encoder. + noise_scale: + The scale of the gaussia noise. + """ + super().__init__() + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + self.fbank_dim = fbank_dim + + self.decoder = decoder + self.decoder_input_dim = decoder_input_dim + self.decoder_dim = decoder_dim + + # decoder embed + self.decoder_embed = nn.Linear( + encoder_dim, decoder_input_dim, bias=True, + ) + # decoder pred to 4 * fbank dim (we concatenate every 4 frames) + self.decoder_pred = nn.Linear( + decoder_dim, fbank_dim * 4, bias=True, + ) + + # mask embeddings + self.mask_emb = nn.Parameter(torch.FloatTensor(fbank_dim).uniform_()) + self.decoder_mask_emb = nn.Parameter(torch.FloatTensor(encoder_dim).normal_()) + + self.mask_prob = mask_prob + self.mask_length = mask_length + self.mask_selection = mask_selection + self.mask_other = mask_other + + self.noise_scale = noise_scale + + def forward_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + target: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + target: + The reconstruction target + Returns: + Return the binary crossentropy loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + N, T, C = x.shape + + padding_mask = make_pad_mask(x_lens) + + # apply masking to the fbank features + x, mask_indices = self.apply_mask_facebook( + x.clone(), + padding_mask=padding_mask + ) # (N,T,C), (N,T) + + x, x_lens = self.encoder_embed(x, x_lens) # (N,T,C) + src_key_padding_mask = make_pad_mask(x_lens) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) # (T,N,C) + + # Normalize encoder features + normalize_factor = (encoder_out ** 2).mean(dim=-1, keepdim=True).sqrt() + encoder_out = encoder_out / normalize_factor + + if self.training: + # add noise to the encoder_out + noise = torch.randn_like(encoder_out, device=encoder_out.device) * self.noise_scale + encoder_out += noise + + # replace the masked encoder_out with a mask_emb + decoder_mask_indices = nn.functional.max_pool1d(mask_indices, 4) + assert decoder_mask_indices.size(1) >= encoder_out.size(0) + if decoder_mask_indices.size(1) > encoder_out.size(0): + decoder_mask_indices = decoder_mask_indices[:, :encoder_out.size(0)] + + decoder_mask_indices = decoder_mask_indices.bool().T + encoder_out[decoder_mask_indices] = self.decoder_mask_emb + + # perform the reconstruction + decoder_src_key_padding_mask = make_pad_mask(encoder_out_lens) + decoder_in = self.decoder_embed(encoder_out) # project to decoder_dim + decoder_out, decoder_out_lens = self.decoder(decoder_in, encoder_out_lens, decoder_src_key_padding_mask) + + decoder_out = self.decoder_pred(decoder_out) + decoder_out = decoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + # compute the reconstruction loss + assert target.size(1) >= 4 * decoder_out.size(1), (target.size(1), decoder_out.size(1)) + target = target[:, : 4 * decoder_out.size(1), :].reshape(N, -1, 4, self.fbank_dim) + target = target.reshape(N, -1, 4 * self.fbank_dim) + l2_loss = nn.functional.mse_loss( + decoder_out, + target, + reduction="none" + ) # (N, T, C) + + # mask the loss on the padding positions + l2_loss.masked_fill_(decoder_src_key_padding_mask.unsqueeze(-1), 0.0) + + # only compute reconstruction loss on masked frames + mask_indices = nn.functional.max_pool1d(mask_indices.float(), 4) + assert mask_indices.size(1) >= decoder_src_key_padding_mask.size(1) + if mask_indices.size(1) > decoder_src_key_padding_mask.size(1): + mask_indices = mask_indices[:, :decoder_src_key_padding_mask.size(1)] + l2_loss *= mask_indices.unsqueeze(-1) + + # normalize the mse loss by the fbank dimension + l2_loss = l2_loss.sum() / self.fbank_dim + + return l2_loss + + def apply_mask_facebook( + self, + x: torch.Tensor, + padding_mask, + ): + # this function is modified from fairseq: https://github.com/facebookresearch/fairseq/blob/bedb259bf34a9fc22073c13a1cee23192fa70ef3/fairseq/models/wav2vec/wav2vec2.py#L429 + # The masked indices have value 1 + B, T, C = x.shape + + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + mask_type=self.mask_selection, + mask_other=self.mask_other, + min_masks=2, + no_overlap=False, # False + min_space=1, # 1 + require_same_masks=False, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + mask_indices = mask_indices.float() + if random.random() > 0.97: + logging.info(f"A proportion of {mask_indices.sum()/mask_indices.numel():.2f} frames are masked") + else: + mask_indices = None + + return x, mask_indices + + +def index_put(tensor, indices, value): + tensor[indices] = value + return tensor + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError("this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + diff --git a/egs/emilia/CLAP/spear/model.py b/egs/emilia/CLAP/spear/model.py new file mode 100644 index 0000000000..37464d477f --- /dev/null +++ b/egs/emilia/CLAP/spear/model.py @@ -0,0 +1,327 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn + +from icefall.utils import make_pad_mask + + +class AsrKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + num_codebooks: int=8, + distillation_layer: int=9, + distillation_delta: int=0, + teacher_frame_ratio: int = 2, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.teacher_frame_ratio = teacher_frame_ratio + self.distillation_delta = distillation_delta + from multi_quantization.prediction import JointCodebookLoss + + self.codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + ) + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + codebook_indexes: + Codebook indexes of teacher embeddings + + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + # align the encoder features with the codebook indexes + if codebook_indexes.shape[1] != encoder_out.shape[1]: + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio + ) + if self.distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :] + encoder_out = encoder_out[:, self.distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + codebook_loss = self.codebook_loss_net( + encoder_out.float(), codebook_indexes + ) + + return codebook_loss + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ] + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes + +class AudioTaggingModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int = 384, + num_events: int = 527, + ): + """An audio tagging model + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + encoder_dim: + Dimension of the encoder. + num_event: + The number of classes. + """ + super().__init__() + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.classifier = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) + + # for multi-class classification + self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum") + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + encoder_out, encoder_out_lens, all_hidden_states = self.encoder(x, x_lens) + + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens, all_hidden_states + + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + target: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + target: + The ground truth label of audio events, could be many hot + Returns: + Return the binary crossentropy loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + + # Compute encoder outputs + encoder_out, encoder_out_lens, _ = self.forward_encoder(x, x_lens) + + # Forward the speaker module + logits = self.forward_audio_tagging( + encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) # (N, num_classes) + + loss = self.criterion(logits, target) + + return loss + + def forward_audio_tagging(self, encoder_out, encoder_out_lens): + """ + Args: + encoder_out: + A 3-D tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + A 3-D tensor of shape (N, num_classes). + """ + logits = self.classifier(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) # mask the padding frames + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( + logits + ) # normalize the logits + + return logits \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/model_asr.py b/egs/emilia/CLAP/spear/model_asr.py new file mode 100644 index 0000000000..f186cb2f50 --- /dev/null +++ b/egs/emilia/CLAP/spear/model_asr.py @@ -0,0 +1,837 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +import torch.nn.functional as F + +from icefall.utils import add_sos, make_pad_mask + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + use_attention_decoder: bool = False, + normalize_fbank: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + normalize_fbank: + If true, the input fbank features is normalized to zero mean and unit variance + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = nn.Linear(encoder_dim, vocab_size) + + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + self.use_attention_decoder = use_attention_decoder + self.attention_decoder = attention_decoder + + self.normalize_fbank = normalize_fbank + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool=False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + with torch.set_grad_enabled((not freeze_encoder) and self.training): + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="sum", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + freeze_encoder: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder( + x, + x_lens, + freeze_encoder=freeze_encoder + ) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + # Compute CTC loss + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss + + +class MultiTaskModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + encoder_downsample: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + use_attention_decoder: bool = False, + num_events: int = 527, + normalize_fbank: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + normalize_fbank: + If true, the input fbank features is normalized to zero mean and unit variance utterance-wise + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_downsample = encoder_downsample + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = nn.Linear(encoder_dim, vocab_size) + + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + self.use_attention_decoder = use_attention_decoder + self.attention_decoder = attention_decoder + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + self.normalize_fbank = normalize_fbank + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool=False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + + # normalise fbank (utterance level) + if self.normalize_fbank: + x = self._normalize_fbank(x, x_lens) + + with torch.set_grad_enabled((not freeze_encoder) and self.training): + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + # if an extra downsample is placed after the encoder + if self.encoder_downsample is not None: + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_downsample(encoder_out) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out_lens = (encoder_out_lens + 1 ) // 2 + + return encoder_out, encoder_out_lens + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="none", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="none", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="none", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + at_targets: torch.Tensor = None, + freeze_encoder: bool = False, + skip_asr: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + device = x.device + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder( + x, + x_lens, + freeze_encoder=freeze_encoder + ) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer and not skip_asr: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc and not skip_asr: + # Compute CTC loss + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + else: + at_loss = None + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, at_loss + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + def forward_audio_tagging_linear_softmax( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + frame_logits = self.audio_tagging_proj(encoder_out) + + # --- Linear Softmax Pooling (Corrected Version) 开始 --- + + # 2. 将Logits转换为帧级别的概率 (激活值) + # (N, T, num_classes) + frame_probabilities = torch.sigmoid(frame_logits) + + # 3. 处理padding,将填充部分的概率设为0 + padding_mask = make_pad_mask(encoder_out_lens, max_len=frame_probabilities.size(1)) # (N, T) + expanded_padding_mask = padding_mask.unsqueeze(-1).expand_as(frame_probabilities) + frame_probabilities[expanded_padding_mask] = 0.0 + + # 4. 计算线性归一化权重 (不使用exp) + # 沿时间维度求和,用于归一化 + # 添加一个小的epsilon防止除以零 + sum_over_time = torch.sum(frame_probabilities, dim=1, keepdim=True) + 1e-7 + + # 权重就是归一化后的概率 + # (N, T, num_classes) + linear_weights = frame_probabilities / sum_over_time + + # 5. 使用线性权重对原始的帧级别概率进行加权求和 + # (N, T, num_classes) * (N, T, num_classes) -> (N, T, num_classes) + # 然后在时间维度上求和 -> (N, num_classes) + clip_probabilities = torch.sum(linear_weights * frame_probabilities, dim=1) + + # --- Linear Softmax Pooling 结束 --- + + if return_logits: # 实际上返回的是概率 + return clip_probabilities + + # 关键修改:由于我们现在得到的是概率(probabilities), + # 损失函数需要使用 F.binary_cross_entropy,而不是 F.binary_cross_entropy_with_logits + at_loss = F.binary_cross_entropy(clip_probabilities, target, reduction="none") + + return at_loss + + @staticmethod + def _normalize_fbank(x: torch.Tensor, x_lens: torch.Tensor, eps: float=1e-9): + """ + x: (B, T, D) fbank 特征,已 padding 到同一 T + x_lens: (B,) 每条样本的有效帧数 (int) + """ + device = x.device + B, T, D = x.shape + + # mask: (B, T, 1) + mask = torch.arange(T, device=device).unsqueeze(0) < x_lens.unsqueeze(1) + mask = mask.unsqueeze(-1) # (B, T, 1), bool + + lengths = x_lens.view(B, 1, 1).to(x.dtype) # (B, 1, 1) + + # 均值 + sum_feats = (x * mask).sum(dim=1, keepdim=True) # (B, 1, D) + mean = sum_feats / lengths + + # 方差 + sum_sq = ((x - mean) * mask).pow(2).sum(dim=1, keepdim=True) + std = torch.sqrt(sum_sq / lengths + eps) + + # 归一化 + x_norm = (x - mean) / (std + eps) + # set masking positions to value 0 + x_norm = x_norm * mask + + return x_norm \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/model_multi_kd.py b/egs/emilia/CLAP/spear/model_multi_kd.py new file mode 100644 index 0000000000..21012eea94 --- /dev/null +++ b/egs/emilia/CLAP/spear/model_multi_kd.py @@ -0,0 +1,414 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from multi_quantization.prediction import JointCodebookLoss + +from icefall.utils import make_pad_mask + + +class MultiKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + num_codebooks: int=8, + distillation_layer: int=9, + distillation_delta: int=0, + teacher_frame_ratio: int = 2, + interpolate_teacher: bool = False, + num_events: int = 527 + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.num_codebooks= num_codebooks + self.teacher_frame_ratio = teacher_frame_ratio + self.interpolate_teacher = interpolate_teacher + self.distillation_delta = distillation_delta + + if num_codebooks > 0: + self.codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + reduction="none", + ) + else: + self.codebook_loss_net = None + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + codebook_indexes: torch.Tensor = None, + at_targets: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + codebook_indexes: + Codebook indexes of teacher embeddings + + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert codebook_indexes is not None or at_targets is not None + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + if codebook_indexes is not None and self.codebook_loss_net is not None: + codebook_loss = self.forward_codebook_loss(encoder_out, encoder_out_lens, codebook_indexes) + else: + codebook_loss = None + + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + else: + at_loss = None + + return codebook_loss, at_loss + + def forward_codebook_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + ): + # align the encoder features with the codebook indexes + if self.interpolate_teacher: + codebook_indexes = self.interpolate_codebook_indexes( + encoder_out, codebook_indexes + ) + else: + if codebook_indexes.shape[1] != encoder_out.shape[1]: + # align the codebook indexes to the frame rate of the student encoder out + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio + ) + + # the delta is associated with the frame-rate of the encoder + # so a bigger delta maybe necessary for 50Hz student encoder + if self.distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :] + encoder_out = encoder_out[:, self.distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + N,T,_ = encoder_out.shape + codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes) + codebook_loss = codebook_loss.reshape(N,T,-1) + num_cb = codebook_loss.size(-1) + # normalize the loss by the number of codebooks + codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb + + return codebook_loss + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + @staticmethod + def interpolate_codebook_indexes(middle_layer_output, codebook_indexes): + # This function addresses the case where the teacher has a lower frame rate + # than the student model + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T) + codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected) + codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C) + + assert codebook_indexes.shape[1] == middle_layer_output.shape[1] + return codebook_indexes + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ], + dim=1, + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes + +class AudioTaggingModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int = 384, + num_events: int = 527, + ): + """An audio tagging model + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + encoder_dim: + Dimension of the encoder. + num_event: + The number of classes. + """ + super().__init__() + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.classifier = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) + + # for multi-class classification + self.criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + target: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + target: + The ground truth label of audio events, could be many hot + Returns: + Return the binary crossentropy loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + # Forward the speaker module + logits = self.forward_audio_tagging( + encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) # (N, num_classes) + + loss = self.criterion(logits, target) + + return loss + + def forward_audio_tagging(self, encoder_out, encoder_out_lens): + """ + Args: + encoder_out: + A 3-D tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + A 3-D tensor of shape (N, num_classes). + """ + logits = self.classifier(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) # mask the padding frames + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( + logits + ) # normalize the logits + + return logits \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/model_multi_kd_co_training.py b/egs/emilia/CLAP/spear/model_multi_kd_co_training.py new file mode 100644 index 0000000000..47b5156d82 --- /dev/null +++ b/egs/emilia/CLAP/spear/model_multi_kd_co_training.py @@ -0,0 +1,477 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from lhotse.dataset import SpecAugment +from multi_quantization2 import JointCodebookLoss + +from icefall.utils import make_pad_mask + + +class MultiKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + num_codebooks: int=8, + distillation_layer: int=9, + distillation_delta: int=0, + teacher_frame_ratio: int = 2, + interpolate_teacher: bool = False, + num_events: int = 527, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.num_codebooks= num_codebooks + self.teacher_frame_ratio = teacher_frame_ratio + self.interpolate_teacher = interpolate_teacher + self.distillation_delta = distillation_delta + + if num_codebooks > 0: + self.codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + checkpoint=False, + reduction="none", + ) + else: + self.codebook_loss_net = None + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + codebook_indexes: torch.Tensor = None, + at_targets: torch.Tensor = None, + use_co_training: bool = True, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + codebook_indexes: + Codebook indexes of teacher embeddings (N,T,C) + at_targets: + The audio tagging target (N, num_events) + + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert codebook_indexes is not None or at_targets is not None + + # create a another copy of augmented input + if use_co_training: + if use_spec_aug: + assert spec_augment is not None + # Apply time warping before input duplicating + assert supervision_segments is not None + # Independently apply frequency masking and time masking to the two copies + x = spec_augment(x.repeat(2, 1, 1)) + else: + x = x.repeat(2, 1, 1) + + x_lens = x_lens.repeat(2) + if codebook_indexes is not None: + codebook_indexes = codebook_indexes.repeat(2,1,1) + if at_targets is not None: + at_targets = at_targets.repeat(2,1) + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + if codebook_indexes is not None and self.codebook_loss_net is not None: + codebook_loss, logprobs = self.forward_codebook_loss(encoder_out, encoder_out_lens, codebook_indexes) + if use_co_training: + codebook_loss *= 0.5 + else: + codebook_loss = None + logprobs = None + + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + if use_co_training: + at_loss *= 0.5 + else: + at_loss = None + + if use_co_training: + co_training_loss = self.forward_co_training_loss(logprobs, encoder_out_lens) + co_training_loss *= 0.5 + else: + co_training_loss = None + + return codebook_loss, at_loss, co_training_loss + + def forward_co_training_loss( + self, + logits: torch.Tensor, + encoder_out_lens: torch.Tensor, + ): + logits = logits.log_softmax(dim=-1) + exchanged_targets = logits.detach().chunk(2, dim=0) # split into two halves + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + + co_training_loss = nn.functional.kl_div( + input=logits, + target=exchanged_targets, + reduction="none", + log_target=True, # we do the log_softmax + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1).unsqueeze(-1) + + N = encoder_out_lens.shape[0] + num_cb = co_training_loss.shape[1] + co_training_loss = co_training_loss.reshape(N, -1, *co_training_loss.shape[-2:]) + co_training_loss = co_training_loss.masked_fill(length_mask, 0.0) + co_training_loss = co_training_loss.sum(dim=(1,2,3))/ num_cb + + return co_training_loss + + + def forward_codebook_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + ): + # align the encoder features with the codebook indexes + if self.interpolate_teacher: + codebook_indexes = self.interpolate_codebook_indexes( + encoder_out, codebook_indexes + ) + else: + if codebook_indexes.shape[1] != encoder_out.shape[1]: + # align the codebook indexes to the frame rate of the student encoder out + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio + ) + + # the delta is associated with the frame-rate of the encoder + # so a bigger delta maybe necessary for 50Hz student encoder + if self.distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :] + encoder_out = encoder_out[:, self.distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + N,T,_ = encoder_out.shape + codebook_loss, logprobs = self.codebook_loss_net(encoder_out.float(), codebook_indexes) + codebook_loss = codebook_loss.reshape(N,T,-1) + num_cb = codebook_loss.size(-1) * (2 / self.teacher_frame_ratio) # TODO: ugly way to keep the value comparable, need to change + # normalize the loss by the number of codebooks + codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb + + return codebook_loss, logprobs + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + @staticmethod + def interpolate_codebook_indexes(middle_layer_output, codebook_indexes): + # This function addresses the case where the teacher has a lower frame rate + # than the student model + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T) + codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected) + codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C) + + assert codebook_indexes.shape[1] == middle_layer_output.shape[1] + return codebook_indexes + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ] + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes + +class AudioTaggingModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int = 384, + num_events: int = 527, + ): + """An audio tagging model + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + encoder_dim: + Dimension of the encoder. + num_event: + The number of classes. + """ + super().__init__() + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.classifier = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) + + # for multi-class classification + self.criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + target: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + target: + The ground truth label of audio events, could be many hot + Returns: + Return the binary crossentropy loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + # Forward the speaker module + logits = self.forward_audio_tagging( + encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) # (N, num_classes) + + loss = self.criterion(logits, target) + + return loss + + def forward_audio_tagging(self, encoder_out, encoder_out_lens): + """ + Args: + encoder_out: + A 3-D tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + A 3-D tensor of shape (N, num_classes). + """ + logits = self.classifier(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) # mask the padding frames + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( + logits + ) # normalize the logits + + return logits \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/model_multi_kd_mae.py b/egs/emilia/CLAP/spear/model_multi_kd_mae.py new file mode 100644 index 0000000000..00cc4476c4 --- /dev/null +++ b/egs/emilia/CLAP/spear/model_multi_kd_mae.py @@ -0,0 +1,396 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from multi_quantization.prediction import JointCodebookLoss + +from icefall.utils import make_pad_mask + + +class MAELoss(torch.nn.Module): + def __init__(self, normalize_mode: str): + super().__init__() + # If True, normalise the target by frame + assert normalize_mode in ["frame", "sample", "batch"] + self.normalize_mode = normalize_mode + + def forward(self, pred: torch.Tensor, target: torch.Tensor,) -> torch.Tensor: + if self.normalize_mode == "frame": # adopted by Dasheng + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + elif self.normalize_mode == "sample": + mean = target.mean(dim=(1, 2), keepdim=True) # per sample + var = target.var(dim=(1, 2), keepdim=True) + target = (target - mean) / (var + 1e-6)**.5 + elif self.norm_by_frame == 'batch': + mean = target.mean() + var = target.var() + target = (target - mean) / (var + 1.e-6)**.5 + + # compute the MSE loss + loss = (pred - target)**2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + return loss + +class MultiKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + decoder: nn.Module, + decoder_dim: int, + num_codebooks: int=8, + distillation_layer: int=9, + distillation_delta: int=0, + teacher_frame_ratio: int = 2, + interpolate_teacher: bool = False, + n_mels: int = 128, + num_events: int = 527, + mae_loss_norm: str = "sample", + mae_downsample_factor: int = 4, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + self.decoder = decoder + self.decoder_dim = decoder_dim + + self.fbank_dim = n_mels + self.mae_downsample_factor = mae_downsample_factor + self.decoder_embed = nn.Linear(encoder_dim, decoder_dim) # projecting encoder_out to decoder dim + self.decoder_pred = nn.Linear(decoder_dim, n_mels * mae_downsample_factor) # we are predicting 4 fbank frames per decoder frame + + # mvq distillation + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.num_codebooks= num_codebooks + self.teacher_frame_ratio = teacher_frame_ratio + self.interpolate_teacher = interpolate_teacher + self.distillation_delta = distillation_delta + + if num_codebooks > 0: + self.codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + reduction="none", + ) + else: + self.codebook_loss_net = None + + self.mae_loss_norm = mae_loss_norm + self.mae_loss = MAELoss(mae_loss_norm) + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward_decoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ): + """Compute the output of decoder + + Args: + x (torch.Tensor): (N,T,C) + x_lens (torch.Tensor): (N,) + """ + x = self.decoder_embed(x) # N,T,C + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1,0,2) # T,N,C + + decoder_out, decoder_out_lens = self.decoder(x, x_lens, src_key_padding_mask) + decoder_out = decoder_out.permute(1,0,2) # N,T,C + + return decoder_out, decoder_out_lens + + def forward_mae_loss( + self, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + # compute the MAE loss + decoder_out, decoder_out_lens = self.forward_decoder( + encoder_out, encoder_out_lens, + ) + pred = self.decoder_pred(decoder_out) # map to 4 * fbank dim + N,T,_ = pred.shape + pred = pred.reshape(N, -1, self.fbank_dim) + + assert pred.shape[2] == target.shape[2] + target = self.truncate_target(pred, target) + loss = self.mae_loss(pred, target) # (N,T) + + padding_mask = ~ make_pad_mask(decoder_out_lens * self.mae_downsample_factor) + loss = loss * padding_mask + loss = loss.sum(dim=1) # (N,) + return loss + + @staticmethod + def truncate_target(pred: torch.Tensor, target: torch.Tensor): + # truncate the target on both sides for better alignment + # Only consider the cases where the target is longer + # pred: (N,T,C) + # target: (N,T,C) + assert target.shape[1] >= pred.shape[1] + if target.shape[1] == pred.shape[1]: + return target + diff = target.shape[1] - pred.shape[1] + if diff == 1: + target = target[:, :-1, :] # throw the last frame + else: + left = diff // 2 + right = diff - left + target = target[:, left:-right, :] # trim on both sides + assert target.shape[1] == pred.shape[1] + + return target + + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + fbank_target: torch.Tensor, + codebook_indexes: torch.Tensor = None, + at_targets: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + fbank_target: + The original fbank features + codebook_indexes: + Codebook indexes of teacher embeddings + + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert codebook_indexes is not None or at_targets is not None + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + # Compute codebook loss + if codebook_indexes is not None and self.codebook_loss_net is not None: + codebook_loss = self.forward_codebook_loss(encoder_out, encoder_out_lens, codebook_indexes) + else: + codebook_loss = None + + # Compute audio tagging loss (if needed) + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + else: + at_loss = None + + mae_loss = self.forward_mae_loss(encoder_out, encoder_out_lens, fbank_target) + + return codebook_loss, at_loss, mae_loss + + def forward_codebook_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + ): + # align the encoder features with the codebook indexes + if self.interpolate_teacher: + codebook_indexes = self.interpolate_codebook_indexes( + encoder_out, codebook_indexes + ) + else: + if codebook_indexes.shape[1] != encoder_out.shape[1]: + # align the codebook indexes to the frame rate of the student encoder out + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio + ) + + # the delta is associated with the frame-rate of the encoder + # so a bigger delta maybe necessary for 50Hz student encoder + if self.distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :] + encoder_out = encoder_out[:, self.distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + N,T,_ = encoder_out.shape + codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes) + codebook_loss = codebook_loss.reshape(N,T,-1) + num_cb = codebook_loss.size(-1) + # normalize the loss by the number of codebooks + codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb + + return codebook_loss + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + @staticmethod + def interpolate_codebook_indexes(middle_layer_output, codebook_indexes): + # This function addresses the case where the teacher has a lower frame rate + # than the student model + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T) + codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected) + codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C) + + assert codebook_indexes.shape[1] == middle_layer_output.shape[1] + return codebook_indexes + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ] + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes diff --git a/egs/emilia/CLAP/spear/model_multi_kd_multi_teacher.py b/egs/emilia/CLAP/spear/model_multi_kd_multi_teacher.py new file mode 100644 index 0000000000..febb67b07a --- /dev/null +++ b/egs/emilia/CLAP/spear/model_multi_kd_multi_teacher.py @@ -0,0 +1,462 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from multi_quantization.prediction import JointCodebookLoss + +from model_multi_kd_w2v2_mask import compute_mask_indices, compute_mask_indices_block, index_put +from icefall.utils import make_pad_mask + + +class MultiKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + num_codebooks: list[int]=None, + distillation_layer: list[int]=None, + distillation_delta: list[int]=None, + teacher_frame_ratio: list[int]=None, + interpolate_teacher: bool = False, + num_events: int = 527, + n_mels: int = 128, + mask_mode: str = "w2v2", + mask_prob: float = 0.65, + mask_length: int = 10, + mask_selection: str = "static", + mask_other: float = 0.0, + min_masks: int = 2, + mask_channel_prob: float = 0.0, + mask_channel_length: int = 10, + mask_channel_selection: str = "static", + mask_channel_other: float = 0.0, + loss_only_mask: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + num_codebooks: + A list of integers, how many codebooks for each target + mask_mode: + The masking mode. + w2v2: the wav2vec2 style of masking, allows overlap + custom: no overlap, therefore bigger masking ratio + mask_prob: + The probability of selecting choosing one frame as the start index + mask_length: + The length of each mask + mask_selection: + How to determine the length of the mask, see ``compute_mask_indices'' + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.num_codebooks= num_codebooks + self.teacher_frame_ratio = teacher_frame_ratio + self.interpolate_teacher = interpolate_teacher + self.distillation_delta = distillation_delta + + self.codebook_loss_heads = nn.ModuleList() + for cb, frame_ratio in zip(num_codebooks, teacher_frame_ratio): + if cb > 0: + codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=cb * frame_ratio, + is_joint=False, + reduction="none", + ) + else: + codebook_loss_net = None + self.codebook_loss_heads.append(codebook_loss_net) + + if len(self.codebook_loss_heads) == 0: + self.codebook_loss_heads = None + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + # masking related + assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}" + self.mask_mode = mask_mode + + self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_()) + self.mask_prob = mask_prob + self.mask_length = mask_length + self.mask_selection = mask_selection + self.mask_other = mask_other + self.min_masks = min_masks + + self.mask_channel_prob = mask_channel_prob + self.mask_channel_length = mask_channel_length + self.mask_channel_selection = mask_channel_selection + self.mask_channel_other = mask_channel_other + + self.loss_only_mask = loss_only_mask + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + codebook_indexes: list[torch.Tensor] = None, + at_targets: torch.Tensor = None, + mask: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + codebook_indexes: + Codebook indexes of teacher embeddings + mask: + If we perform w2v2 style of masking over the fbank frames + + Returns: + Return the codebook loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert codebook_indexes is not None or at_targets is not None + + # apply masking + if self.training and mask: + padding_mask = make_pad_mask(x_lens) + + # apply masking to the fbank features + x, mask_indices = self.apply_mask( + x.clone(), + padding_mask=padding_mask + ) # (N,T,C), (N,T) + else: + mask_indices = None + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + cb_losses = [] + if self.codebook_loss_heads is not None: + for i, cb_loss_net in enumerate(self.codebook_loss_heads): + cb_indexes = codebook_indexes[i] + if cb_indexes is not None and cb_loss_net is not None: + codebook_loss = self.forward_codebook_loss( + encoder_out, + encoder_out_lens, + cb_indexes, + cb_loss_net=cb_loss_net, + teacher_frame_ratio=self.teacher_frame_ratio[i], + distillation_delta=self.distillation_delta[i], + reduction="none" + ) + if self.loss_only_mask and mask_indices is not None: + # downsample the mask + mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5 + assert mask_indices.size(1) >= codebook_loss.size(1) + mask_indices = mask_indices[:, :codebook_loss.size(1)].float() + codebook_loss = codebook_loss * mask_indices + codebook_loss = codebook_loss.sum(dim=1) # (B,) + else: + codebook_loss = 0.0 + cb_losses.append(codebook_loss) + + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + else: + at_loss = None + + return *cb_losses, at_loss + + def forward_codebook_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + cb_loss_net: torch.nn.Module, + teacher_frame_ratio: int, + distillation_delta: int, + reduction: str = "sum", + ): + # align the encoder features with the codebook indexes + if self.interpolate_teacher: + codebook_indexes = self.interpolate_codebook_indexes( + encoder_out, codebook_indexes + ) + else: + if codebook_indexes.shape[1] != encoder_out.shape[1]: + # align the codebook indexes to the frame rate of the student encoder out + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=teacher_frame_ratio + ) + + # the delta is associated with the frame-rate of the student encoder + # so a bigger delta maybe necessary for 50Hz student encoder + if distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-distillation_delta, :] + encoder_out = encoder_out[:, distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + # compute the loss + N,T,_ = encoder_out.shape + codebook_loss = cb_loss_net(encoder_out.float(), codebook_indexes) + codebook_loss = codebook_loss.reshape(N,T,-1) + num_cb = codebook_loss.size(-1) # this is the equivalent number of codebooks + + # normalize the loss by the number of codebooks + if reduction == "sum": + codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb # (B,) + elif reduction == "none": + codebook_loss = codebook_loss.sum(dim=2) / num_cb # (B,T) + else: + raise NotImplementedError() + + return codebook_loss + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + def apply_mask( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply mask according to the mask_mode, return the masked features and the masked positions + + Args: + x (torch.Tensor): The input fbank features + padding_mask (torch.Tensor, optional): The padding mask + + Returns: + The masked fbank feature and the masked_indices, with masked positions as 1 + """ + # apply mask to the fbank features, two modes applicable + if self.mask_mode == "w2v2": + x, masked_indices = self.apply_mask_w2v2(x, padding_mask) + elif self.mask_mode == "block": + x, masked_indices = self.apply_mask_block(x, padding_mask) + else: + raise NotImplementedError() + + if random.random() > 0.97: + logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked") + return x, masked_indices + + def apply_mask_block( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ): + B,T,C = x.shape + assert self.mask_prob > 0.0 + + mask_indices = compute_mask_indices_block( + shape=(B,T), + padding_mask=padding_mask, + mask_prob=self.mask_prob, + mask_length=self.mask_length, + min_masks=self.min_masks, + ).to(x.device) + + x = index_put(x, mask_indices.bool(), self.mask_emb) + + return x, mask_indices + + def apply_mask_w2v2( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ): + # this function is modified from fairseq: https://github.com/facebookresearch/fairseq/blob/bedb259bf34a9fc22073c13a1cee23192fa70ef3/fairseq/models/wav2vec/wav2vec2.py#L429 + # The masked indices have value 1 + B, T, C = x.shape + + # we mask channel first, then mask timestamps + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=False, + min_space=1, + require_same_masks=False, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + if random.random() > 0.98: + logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked") + x[mask_channel_indices] = 0 + + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + mask_type=self.mask_selection, + mask_other=self.mask_other, + min_masks=2, # fixed + no_overlap=False, # False + min_space=1, # 1 + require_same_masks=False, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + mask_indices = mask_indices.float() + else: + mask_indices = None + + return x, mask_indices + + @staticmethod + def interpolate_codebook_indexes(middle_layer_output, codebook_indexes): + # This function addresses the case where the teacher has a lower frame rate + # than the student model + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T) + codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected) + codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C) + + assert codebook_indexes.shape[1] == middle_layer_output.shape[1] + return codebook_indexes + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ] + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes diff --git a/egs/emilia/CLAP/spear/model_multi_kd_w2v2_mask.py b/egs/emilia/CLAP/spear/model_multi_kd_w2v2_mask.py new file mode 100644 index 0000000000..6c3f9ac2c9 --- /dev/null +++ b/egs/emilia/CLAP/spear/model_multi_kd_w2v2_mask.py @@ -0,0 +1,815 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# Copyright 2025 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Optional, Tuple +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from multi_quantization.prediction import JointCodebookLoss + +from icefall.utils import make_pad_mask + + +class MultiKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + num_codebooks: int=8, + distillation_layer: int=9, + distillation_delta: int=0, + teacher_frame_ratio: int = 2, + interpolate_teacher: bool = False, + n_mels: int = 128, + num_events: int = 527, + mask_mode: str = "w2v2", + mask_prob: float = 0.65, + mask_length: int = 10, + mask_selection: str = "static", + mask_other: float = 0.0, + min_masks: int = 2, + mask_channel_prob: float = 0.0, + mask_channel_length: int = 10, + mask_channel_selection: str = "static", + mask_channel_other: float = 0.0, + loss_only_mask: bool = False, + normalize_fbank: bool = False, + ): + """A model that performs MVQ KD pre-training . + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + num_codebooks: + The number of codebooks used in the target + distillation_layer: + Use which layer to do MVQ pre-training + distillation_delta: + How many frames to delay the alignment between the model and the target frames. + Should be zero for non-streaming models, and a positive number for streaming models + teacher_frame_ratio: + The frame rate ratio between the target and the model output + mask_mode: + The masking mode. + w2v2: the wav2vec2 style of masking, allows overlap + custom: no overlap, therefore bigger masking ratio + mask_prob: + The probability of selecting choosing one frame as the start index + mask_length: + The length of each mask + mask_selection: + How to determine the length of the mask, see ``compute_mask_indices'' + normalize_fbank: + If true, the input fbank features is normalized to zero mean and unit variance + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.num_codebooks= num_codebooks + self.teacher_frame_ratio = teacher_frame_ratio + self.interpolate_teacher = interpolate_teacher + self.distillation_delta = distillation_delta + + if num_codebooks > 0: + self.codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + reduction="none", + ) + else: + self.codebook_loss_net = None + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + # masking related + assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}" + self.mask_mode = mask_mode + + self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_()) + self.mask_prob = mask_prob + self.mask_length = mask_length + self.mask_selection = mask_selection + self.mask_other = mask_other + self.min_masks = min_masks + + self.mask_channel_prob = mask_channel_prob + self.mask_channel_length = mask_channel_length + self.mask_channel_selection = mask_channel_selection + self.mask_channel_other = mask_channel_other + + self.loss_only_mask = loss_only_mask + self.normalize_fbank = normalize_fbank + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # normalise fbank (utterance level) + if self.normalize_fbank: + x = self._normalize_fbank(x, x_lens) + + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + @staticmethod + def _normalize_fbank(x: torch.Tensor, x_lens: torch.Tensor, eps: float=1e-9): + """ + x: (B, T, D) fbank 特征,已 padding 到同一 T + x_lens: (B,) 每条样本的有效帧数 (int) + """ + device = x.device + B, T, D = x.shape + + # mask: (B, T, 1) + mask = torch.arange(T, device=device).unsqueeze(0) < x_lens.unsqueeze(1) + mask = mask.unsqueeze(-1) # (B, T, 1), bool + + lengths = x_lens.view(B, 1, 1).to(x.dtype) # (B, 1, 1) + + # 均值 + sum_feats = (x * mask).sum(dim=1, keepdim=True) # (B, 1, D) + mean = sum_feats / lengths + + # 方差 + sum_sq = ((x - mean) * mask).pow(2).sum(dim=1, keepdim=True) + std = torch.sqrt(sum_sq / lengths + eps) + + # 归一化 + x_norm = (x - mean) / (std + eps) + # set masking positions to value 0 + x_norm = x_norm * mask + + return x_norm + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + codebook_indexes: torch.Tensor = None, + at_targets: torch.Tensor = None, + mask: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + codebook_indexes: + Codebook indexes of teacher embeddings + mask: + If we perform w2v2 style of masking over the fbank frames + + Returns: + Return the codebook loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert codebook_indexes is not None or at_targets is not None + + # apply masking + if self.training and mask: + padding_mask = make_pad_mask(x_lens) + + # apply masking to the fbank features + x, mask_indices = self.apply_mask( + x.clone(), + padding_mask=padding_mask + ) # (N,T,C), (N,T) + else: + mask_indices = None + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + if codebook_indexes is not None and self.codebook_loss_net is not None: + codebook_loss = self.forward_codebook_loss( + encoder_out, encoder_out_lens, codebook_indexes, reduction="none" + ) + if self.loss_only_mask and mask_indices is not None: + # downsample the mask + mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5 + assert mask_indices.size(1) >= codebook_loss.size(1) + mask_indices = mask_indices[:, :codebook_loss.size(1)].float() + codebook_loss = codebook_loss * mask_indices + codebook_loss = codebook_loss.sum(dim=1) # (B,) + else: + codebook_loss = None + + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + else: + at_loss = None + + return codebook_loss, at_loss + + def forward_codebook_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + reduction: str = "sum", + ): + # align the encoder features with the codebook indexes + if self.interpolate_teacher: + codebook_indexes = self.interpolate_codebook_indexes( + encoder_out, codebook_indexes + ) + else: + if codebook_indexes.shape[1] != encoder_out.shape[1]: + # align the codebook indexes to the frame rate of the student encoder out + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio + ) + + # the delta is associated with the frame-rate of the encoder + # so a bigger delta maybe necessary for 50Hz student encoder + if self.distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :] + encoder_out = encoder_out[:, self.distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + N,T,_ = encoder_out.shape + codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes) + codebook_loss = codebook_loss.reshape(N,T,-1) + num_cb = codebook_loss.size(-1) + # normalize the loss by the number of codebooks + if reduction == "sum": + codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb # (B,) + elif reduction == "none": + codebook_loss = codebook_loss.sum(dim=2) / num_cb # (B,T) + else: + raise NotImplementedError() + + return codebook_loss + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + def apply_mask( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply mask according to the mask_mode, return the masked features and the masked positions + + Args: + x (torch.Tensor): The input fbank features + padding_mask (torch.Tensor, optional): The padding mask + + Returns: + The masked fbank feature and the masked_indices, with masked positions as 1 + """ + # apply mask to the fbank features, two modes applicable + if self.mask_mode == "w2v2": + x, masked_indices = self.apply_mask_w2v2(x, padding_mask) + elif self.mask_mode == "block": + x, masked_indices = self.apply_mask_block(x, padding_mask) + else: + raise NotImplementedError() + + if random.random() > 0.97: + logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked") + return x, masked_indices + + + def apply_mask_block( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ): + B,T,C = x.shape + assert self.mask_prob > 0.0 + + mask_indices = compute_mask_indices_block( + shape=(B,T), + padding_mask=padding_mask, + mask_prob=self.mask_prob, + mask_length=self.mask_length, + min_masks=self.min_masks, + ).to(x.device) + + x = index_put(x, mask_indices.bool(), self.mask_emb) + + return x, mask_indices + + def apply_mask_w2v2( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ): + # this function is modified from fairseq: https://github.com/facebookresearch/fairseq/blob/bedb259bf34a9fc22073c13a1cee23192fa70ef3/fairseq/models/wav2vec/wav2vec2.py#L429 + # The masked indices have value 1 + B, T, C = x.shape + + # we mask channel first, then mask timestamps + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=False, + min_space=1, + require_same_masks=False, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + if random.random() > 0.98: + logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked") + x[mask_channel_indices] = 0 + + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + mask_type=self.mask_selection, + mask_other=self.mask_other, + min_masks=2, # fixed + no_overlap=False, # False + min_space=1, # 1 + require_same_masks=False, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + mask_indices = mask_indices.float() + else: + mask_indices = None + + return x, mask_indices + + @staticmethod + def interpolate_codebook_indexes(middle_layer_output, codebook_indexes): + # This function addresses the case where the teacher has a lower frame rate + # than the student model + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T) + codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected) + codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C) + + assert codebook_indexes.shape[1] == middle_layer_output.shape[1] + return codebook_indexes + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ], + dim=1, + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes + +def index_put(tensor, indices, value): + tensor[indices] = value + return tensor + +def compute_mask_indices_block( + shape, + padding_mask, + mask_prob: float = 0.5, + mask_length: int = 10, + min_masks: int = 2, +): + # self-implemented mask, no overlap + B,T = shape + mask_indices = [] + for i in range(B): + if padding_mask is not None: + num_segments = (T - padding_mask[i].sum()) // mask_length # discard the last few frames + else: + num_segments = T // mask_length + segment_mask = torch.rand(num_segments) < mask_prob + while sum(segment_mask) < min_masks: + segment_mask = torch.rand(num_segments) < mask_prob + segment_mask_expanded = segment_mask.unsqueeze(-1).expand(num_segments, mask_length) + segment_mask_expanded = segment_mask_expanded.reshape(-1).float() + if segment_mask_expanded.size(0) < T: + pad = T - segment_mask_expanded.size(0) + segment_mask_expanded = torch.cat([segment_mask_expanded, torch.zeros(pad)]) + mask_indices.append(segment_mask_expanded) + + mask_indices = torch.stack(mask_indices) + return mask_indices + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + hard_max = sz // mask_length + num_mask = min(hard_max, num_mask) # prevent whole sequence being masked + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError("this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + +def _test_w2v2_channel_mask(): + x = torch.ones(100, 1000, 128) + B, T, C = x.shape + + configs = [(0.25, 15), (0.25, 20), (0.5, 15),] + # configs = [(0.2, 20), (0.3, 20), (0.4, 20),] + for config in configs: + mask_channel_prob, mask_channel_length = config + ratios = [] + for i in range(20): + mask_channel_indices = compute_mask_indices( + (B, C), + None, + mask_channel_prob, + mask_channel_length, + "static", + 0.0, + no_overlap=False, + min_space=1, + require_same_masks=False, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + ratio = mask_channel_indices.sum() / mask_channel_indices.numel() + ratios.append(ratio) + import pdb; pdb.set_trace() + avg_ratio = sum(ratios) / len(ratios) + print(f"Current config: mask_channel_prob = {mask_channel_prob}, mask_channel_length = {mask_channel_length}") + print(f"Averaged masking ratio: {avg_ratio}") + +def _test_w2v2_mask(): + x = torch.ones(100, 1000, 128) + B, T, C = x.shape + + mask_prob = 0.65 + mask_length = 10 + + # configs = [(0.65, 10), (0.01, 40), (0.1, 40), (0.2, 40), (0.2, 20), (0.35, 10), (0.35, 20), (0.25, 20)] + configs = [] + for i in range(6): + p = 0.05 + (i+1) * 0.1 + for l in [10, 20, 30, 40]: + configs.append((p, l)) + configs = [(0.65, 10), (0.02, 40), (0.05, 40), (0.1, 40)] + for config in configs: + mask_prob, mask_length = config + ratios = [] + for i in range(20): + mask_indices = compute_mask_indices( + (B, T), + None, + mask_prob, + mask_length, + mask_type="static", + mask_other=0.0, + min_masks=2, + no_overlap=False, # False + min_space=1, # 1 + require_same_masks=False, + ) + mask_indices = torch.from_numpy(mask_indices) + ratio = mask_indices.sum() / mask_indices.numel() + ratios.append(ratio) + avg_ratio = sum(ratios) / len(ratios) + print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}") + print(f"Averaged masking ratio: {avg_ratio}") + +def _test_custom_mask(): + x = torch.ones(100, 1000, 128) + B, T, C = x.shape + + configs = [(0.5, 20), (0.2, 20), (0.3, 20), (0.4, 20), (0.5, 20)] + for config in configs: + mask_prob, mask_length = config + ratios = [] + for i in range(20): + all_possible_mask_lengths = [mask_length + i * 2 for i in range(-5, 6)] + mask_length = random.sample(all_possible_mask_lengths, 1)[0] + assert mask_length > 0, f"Sampled mask_length smaller than 0, {mask_length}" + + mask_indices = compute_mask_indices_block( + shape=(B, T), + padding_mask=None, + mask_prob=mask_prob, + mask_length=mask_length, + min_masks=2, + ) + import pdb; pdb.set_trace() + ratio = mask_indices.sum() / mask_indices.numel() + ratios.append(ratio) + avg_ratio = sum(ratios) / len(ratios) + print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}") + print(f"Averaged masking ratio: {avg_ratio}") + +def _test_specaug_feature(): + pass + +if __name__=="__main__": + _test_w2v2_channel_mask() + # _test_w2v2_mask() + # _test_custom_mask() \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/model_multi_kd_w2v2_mask_inter_cb.py b/egs/emilia/CLAP/spear/model_multi_kd_w2v2_mask_inter_cb.py new file mode 100644 index 0000000000..5aede0c980 --- /dev/null +++ b/egs/emilia/CLAP/spear/model_multi_kd_w2v2_mask_inter_cb.py @@ -0,0 +1,850 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# Copyright 2025 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Optional, Tuple +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from multi_quantization.prediction import JointCodebookLoss + +from icefall.utils import make_pad_mask + + +class MultiKDModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + num_codebooks: int=8, + distillation_layer: int=9, + distillation_delta: int=0, + teacher_frame_ratio: int = 2, + interpolate_teacher: bool = False, + n_mels: int = 128, + num_events: int = 527, + mask_mode: str = "w2v2", + mask_prob: float = 0.65, + mask_length: int = 10, + mask_selection: str = "static", + mask_other: float = 0.0, + min_masks: int = 2, + mask_channel_prob: float = 0.0, + mask_channel_length: int = 10, + mask_channel_selection: str = "static", + mask_channel_other: float = 0.0, + loss_only_mask: bool = False, + normalize_fbank: bool = False, + intermediate_cb: bool = False, + intermediate_block_idx: int = -1, + ): + """A model that performs MVQ KD pre-training . + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + num_codebooks: + The number of codebooks used in the target + distillation_layer: + Use which layer to do MVQ pre-training + distillation_delta: + How many frames to delay the alignment between the model and the target frames. + Should be zero for non-streaming models, and a positive number for streaming models + teacher_frame_ratio: + The frame rate ratio between the target and the model output + mask_mode: + The masking mode. + w2v2: the wav2vec2 style of masking, allows overlap + custom: no overlap, therefore bigger masking ratio + mask_prob: + The probability of selecting choosing one frame as the start index + mask_length: + The length of each mask + mask_selection: + How to determine the length of the mask, see ``compute_mask_indices'' + normalize_fbank: + If true, the input fbank features is normalized to zero mean and unit variance + intermediate_cb: + Perform an extra intermediate codebook distillation + """ + super().__init__() + + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_dim = encoder_dim + + self.distillation_layer = distillation_layer + # the frame ratio between the teacher and student + # if larger than one, we are basically having more than one set of + # codebooks for each frame + self.num_codebooks= num_codebooks + self.teacher_frame_ratio = teacher_frame_ratio + self.interpolate_teacher = interpolate_teacher + self.distillation_delta = distillation_delta + + self.intermediate_cb = intermediate_cb + self.intermediate_block_idx = intermediate_block_idx + assert intermediate_block_idx >= 0, "intermediate_block_idx should be a positive number" + if num_codebooks > 0: + self.codebook_loss_net = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + reduction="none", + ) + if intermediate_cb: + # add an extra codebook loss for the intermediate layer + # note that we only support uniform dimension encoder so far + self.codebook_loss_net_inter = JointCodebookLoss( + predictor_channels=encoder_dim, + num_codebooks=num_codebooks * self.teacher_frame_ratio, + is_joint=False, + reduction="none", + ) + else: + self.codebook_loss_net_inter = None + else: + self.codebook_loss_net = None + self.codebook_loss_net_inter = None + + self.audio_tagging_proj = nn.Sequential( + nn.Dropout(0.1), + nn.Linear(encoder_dim, num_events), + ) # 527 classes + + # masking related + assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}" + self.mask_mode = mask_mode + + self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_()) + self.mask_prob = mask_prob + self.mask_length = mask_length + self.mask_selection = mask_selection + self.mask_other = mask_other + self.min_masks = min_masks + + self.mask_channel_prob = mask_channel_prob + self.mask_channel_length = mask_channel_length + self.mask_channel_selection = mask_channel_selection + self.mask_channel_other = mask_channel_other + + self.loss_only_mask = loss_only_mask + self.normalize_fbank = normalize_fbank + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask, return_middle_out=True) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens, middle_out + + @staticmethod + def _normalize_fbank(x: torch.Tensor, x_lens: torch.Tensor, eps: float=1e-9): + """ + x: (B, T, D) fbank 特征,已 padding 到同一 T + x_lens: (B,) 每条样本的有效帧数 (int) + """ + device = x.device + B, T, D = x.shape + + # mask: (B, T, 1) + mask = torch.arange(T, device=device).unsqueeze(0) < x_lens.unsqueeze(1) + mask = mask.unsqueeze(-1) # (B, T, 1), bool + + lengths = x_lens.view(B, 1, 1).to(x.dtype) # (B, 1, 1) + + # 均值 + sum_feats = (x * mask).sum(dim=1, keepdim=True) # (B, 1, D) + mean = sum_feats / lengths + + # 方差 + sum_sq = ((x - mean) * mask).pow(2).sum(dim=1, keepdim=True) + std = torch.sqrt(sum_sq / lengths + eps) + + # 归一化 + x_norm = (x - mean) / (std + eps) + # set masking positions to value 0 + x_norm = x_norm * mask + + return x_norm + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + codebook_indexes: torch.Tensor = None, + at_targets: torch.Tensor = None, + mask: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + codebook_indexes: + Codebook indexes of teacher embeddings + mask: + If we perform w2v2 style of masking over the fbank frames + + Returns: + Return the codebook loss + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert codebook_indexes is not None or at_targets is not None + + # normalise fbank (utterance level) + if self.normalize_fbank: + x = self._normalize_fbank(x, x_lens) + + # apply masking + if self.training and mask: + padding_mask = make_pad_mask(x_lens) + + # apply masking to the fbank features + x, mask_indices = self.apply_mask( + x.clone(), + padding_mask=padding_mask + ) # (N,T,C), (N,T) + else: + mask_indices = None + + # Compute encoder outputs + encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens) + + if codebook_indexes is not None and self.codebook_loss_net is not None: + codebook_loss = self.forward_codebook_loss( + encoder_out, encoder_out_lens, codebook_indexes, self.codebook_loss_net, reduction="none" + ) + if self.loss_only_mask and mask_indices is not None: + # downsample the mask + ds_mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5 + assert ds_mask_indices.size(1) >= codebook_loss.size(1) + ds_mask_indices = ds_mask_indices[:, :codebook_loss.size(1)].float() + codebook_loss = codebook_loss * ds_mask_indices + codebook_loss = codebook_loss.sum(dim=1) # (B,) + else: + codebook_loss = None + + if codebook_indexes is not None and self.intermediate_cb and self.codebook_loss_net_inter is not None: + # compute the codebook loss for the intermediate layer + middle_out = middle_out[self.intermediate_block_idx].permute(1,0,2) # (N,T,C) + codebook_loss_inter = self.forward_codebook_loss( + middle_out, encoder_out_lens, codebook_indexes, self.codebook_loss_net_inter, reduction="none" + ) + if self.loss_only_mask and mask_indices is not None: + ds_mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5 + assert ds_mask_indices.size(1) >= codebook_loss_inter.size(1) + ds_mask_indices = ds_mask_indices[:, :codebook_loss_inter.size(1)].float() + codebook_loss_inter = codebook_loss_inter * ds_mask_indices + codebook_loss_inter = codebook_loss_inter.sum(dim=1) + else: + codebook_loss_inter = None + + if at_targets is not None: + at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False) + else: + at_loss = None + + return codebook_loss, at_loss, codebook_loss_inter + + def forward_codebook_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + codebook_indexes: torch.Tensor, + codebook_loss_net: nn.Module, + reduction: str = "sum", + ): + # align the encoder features with the codebook indexes + if self.interpolate_teacher: + codebook_indexes = self.interpolate_codebook_indexes( + encoder_out, codebook_indexes + ) + else: + if codebook_indexes.shape[1] != encoder_out.shape[1]: + # align the codebook indexes to the frame rate of the student encoder out + codebook_indexes = self.concat_successive_codebook_indexes( + encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio + ) + + # the delta is associated with the frame-rate of the encoder + # so a bigger delta maybe necessary for 50Hz student encoder + if self.distillation_delta > 0: + codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :] + encoder_out = encoder_out[:, self.distillation_delta:, :] + truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta) + codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100) + + N,T,_ = encoder_out.shape + codebook_loss = codebook_loss_net(encoder_out.float(), codebook_indexes) + codebook_loss = codebook_loss.reshape(N,T,-1) + num_cb = codebook_loss.size(-1) + # normalize the loss by the number of codebooks + if reduction == "sum": + codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb # (B,) + elif reduction == "none": + codebook_loss = codebook_loss.sum(dim=2) / num_cb # (B,T) + else: + raise NotImplementedError() + + return codebook_loss + + def forward_audio_tagging( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + target: torch.Tensor = None, + return_logits: bool = False, + ): + # target: (N, num_events) + logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) # (N,T) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events) + if return_logits: + return logits + + at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + + return at_loss + + def apply_mask( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply mask according to the mask_mode, return the masked features and the masked positions + + Args: + x (torch.Tensor): The input fbank features + padding_mask (torch.Tensor, optional): The padding mask + + Returns: + The masked fbank feature and the masked_indices, with masked positions as 1 + """ + # apply mask to the fbank features, two modes applicable + if self.mask_mode == "w2v2": + x, masked_indices = self.apply_mask_w2v2(x, padding_mask) + elif self.mask_mode == "block": + x, masked_indices = self.apply_mask_block(x, padding_mask) + else: + raise NotImplementedError() + + if random.random() > 0.97: + logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked") + return x, masked_indices + + + def apply_mask_block( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ): + B,T,C = x.shape + assert self.mask_prob > 0.0 + + mask_indices = compute_mask_indices_block( + shape=(B,T), + padding_mask=padding_mask, + mask_prob=self.mask_prob, + mask_length=self.mask_length, + min_masks=self.min_masks, + ).to(x.device) + + x = index_put(x, mask_indices.bool(), self.mask_emb) + + return x, mask_indices + + def apply_mask_w2v2( + self, + x: torch.Tensor, + padding_mask: torch.Tensor = None + ): + # this function is modified from fairseq: https://github.com/facebookresearch/fairseq/blob/bedb259bf34a9fc22073c13a1cee23192fa70ef3/fairseq/models/wav2vec/wav2vec2.py#L429 + # The masked indices have value 1 + B, T, C = x.shape + + # we mask channel first, then mask timestamps + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=False, + min_space=1, + require_same_masks=False, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + if random.random() > 0.98: + logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked") + x[mask_channel_indices] = 0 + + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + mask_type=self.mask_selection, + mask_other=self.mask_other, + min_masks=2, # fixed + no_overlap=False, # False + min_space=1, # 1 + require_same_masks=False, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + mask_indices = mask_indices.float() + else: + mask_indices = None + + return x, mask_indices + + @staticmethod + def interpolate_codebook_indexes(middle_layer_output, codebook_indexes): + # This function addresses the case where the teacher has a lower frame rate + # than the student model + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T) + codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected) + codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C) + + assert codebook_indexes.shape[1] == middle_layer_output.shape[1] + return codebook_indexes + + @staticmethod + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2): + # Output rate of hubert is 50 frames per second, + # while that of current encoder is 25. + # Following code handling two issues: + # 1. + # Roughly speaking, to generate another frame output, + # hubert needes extra two frames, + # while current encoder needs extra four frames. + # Suppose there are only extra three frames provided, + # hubert will generate another frame while current encoder does nothing. + # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape # C should be 256 + + # Handling issue 1. + if T >= t_expected * ratio: + codebook_indexes = codebook_indexes[:, : t_expected * ratio, :] + else: + assert t_expected * ratio - T <= 5, (T, t_expected, ratio) + diff = t_expected * ratio - T + codebook_indexes = torch.cat( + [ + codebook_indexes, + torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype) + ], + dim=1, + ) + assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio + + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes + +def index_put(tensor, indices, value): + tensor[indices] = value + return tensor + +def compute_mask_indices_block( + shape, + padding_mask, + mask_prob: float = 0.5, + mask_length: int = 10, + min_masks: int = 2, +): + # self-implemented mask, no overlap + B,T = shape + mask_indices = [] + for i in range(B): + if padding_mask is not None: + num_segments = (T - padding_mask[i].sum()) // mask_length # discard the last few frames + else: + num_segments = T // mask_length + segment_mask = torch.rand(num_segments) < mask_prob + while sum(segment_mask) < min_masks: + segment_mask = torch.rand(num_segments) < mask_prob + segment_mask_expanded = segment_mask.unsqueeze(-1).expand(num_segments, mask_length) + segment_mask_expanded = segment_mask_expanded.reshape(-1).float() + if segment_mask_expanded.size(0) < T: + pad = T - segment_mask_expanded.size(0) + segment_mask_expanded = torch.cat([segment_mask_expanded, torch.zeros(pad)]) + mask_indices.append(segment_mask_expanded) + + mask_indices = torch.stack(mask_indices) + return mask_indices + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + hard_max = sz // mask_length + num_mask = min(hard_max, num_mask) # prevent whole sequence being masked + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError("this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + +def _test_w2v2_channel_mask(): + x = torch.ones(100, 1000, 128) + B, T, C = x.shape + + configs = [(0.25, 15), (0.25, 20), (0.5, 15),] + # configs = [(0.2, 20), (0.3, 20), (0.4, 20),] + for config in configs: + mask_channel_prob, mask_channel_length = config + ratios = [] + for i in range(20): + mask_channel_indices = compute_mask_indices( + (B, C), + None, + mask_channel_prob, + mask_channel_length, + "static", + 0.0, + no_overlap=False, + min_space=1, + require_same_masks=False, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + ratio = mask_channel_indices.sum() / mask_channel_indices.numel() + ratios.append(ratio) + import pdb; pdb.set_trace() + avg_ratio = sum(ratios) / len(ratios) + print(f"Current config: mask_channel_prob = {mask_channel_prob}, mask_channel_length = {mask_channel_length}") + print(f"Averaged masking ratio: {avg_ratio}") + +def _test_w2v2_mask(): + x = torch.ones(100, 1000, 128) + B, T, C = x.shape + + mask_prob = 0.65 + mask_length = 10 + + # configs = [(0.65, 10), (0.01, 40), (0.1, 40), (0.2, 40), (0.2, 20), (0.35, 10), (0.35, 20), (0.25, 20)] + configs = [] + for i in range(6): + p = 0.05 + (i+1) * 0.1 + for l in [10, 20, 30, 40]: + configs.append((p, l)) + configs = [(0.65, 10), (0.02, 40), (0.05, 40), (0.1, 40)] + for config in configs: + mask_prob, mask_length = config + ratios = [] + for i in range(20): + mask_indices = compute_mask_indices( + (B, T), + None, + mask_prob, + mask_length, + mask_type="static", + mask_other=0.0, + min_masks=2, + no_overlap=False, # False + min_space=1, # 1 + require_same_masks=False, + ) + mask_indices = torch.from_numpy(mask_indices) + ratio = mask_indices.sum() / mask_indices.numel() + ratios.append(ratio) + avg_ratio = sum(ratios) / len(ratios) + print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}") + print(f"Averaged masking ratio: {avg_ratio}") + +def _test_custom_mask(): + x = torch.ones(100, 1000, 128) + B, T, C = x.shape + + configs = [(0.5, 20), (0.2, 20), (0.3, 20), (0.4, 20), (0.5, 20)] + for config in configs: + mask_prob, mask_length = config + ratios = [] + for i in range(20): + all_possible_mask_lengths = [mask_length + i * 2 for i in range(-5, 6)] + mask_length = random.sample(all_possible_mask_lengths, 1)[0] + assert mask_length > 0, f"Sampled mask_length smaller than 0, {mask_length}" + + mask_indices = compute_mask_indices_block( + shape=(B, T), + padding_mask=None, + mask_prob=mask_prob, + mask_length=mask_length, + min_masks=2, + ) + import pdb; pdb.set_trace() + ratio = mask_indices.sum() / mask_indices.numel() + ratios.append(ratio) + avg_ratio = sum(ratios) / len(ratios) + print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}") + print(f"Averaged masking ratio: {avg_ratio}") + +def _test_specaug_feature(): + pass + +if __name__=="__main__": + _test_w2v2_channel_mask() + # _test_w2v2_mask() + # _test_custom_mask() \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/mtl_datamodule.py b/egs/emilia/CLAP/spear/mtl_datamodule.py new file mode 100644 index 0000000000..7edac0f797 --- /dev/null +++ b/egs/emilia/CLAP/spear/mtl_datamodule.py @@ -0,0 +1,1271 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 University of Cambridge (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache, cached_property +from pathlib import Path +from typing import Any, Dict, Optional, Union, List + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + ZipSampler, + SpecAugment, + WeightedSimpleCutSampler, + make_worker_init_fn, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from dataset import MultiTaskDataset +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MultiTaskDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--use-shar", + type=str2bool, + default=False, + ) + group.add_argument( + "--shar-dir", + type=Path, + default=Path("data-shar"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--zip-sampler", + type=str2bool, + default=False, + help="""If use a zip sampler to combine samplers from each task. + This cannot be used together with bucketing sampler. Only one of + them can be true.""" + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--time-mask-ratio", + type=float, + default=1.0, + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=-1, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--features-mask-size", + type=int, + default=27, + help="The maximum mask bins along the frequency axis in specaug" + ) + + group.add_argument( + "--frames-mask-size", + type=int, + default=100, + help="The maximum mask length along the time axis in specaug" + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # ASR related + group.add_argument( + "--use-librispeech", + type=str2bool, + default=True, + ) + + group.add_argument( + "--repeat-librispeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-gigaspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--gigaspeech-subset", + type=str, + default="m", + choices=["xs", "s", "m", "l", "xl"] + ) + + group.add_argument( + "--use-libriheavy", + type=str2bool, + default=False, + ) + + group.add_argument( + "--libriheavy-subset", + type=str, + default="medium", + ) + + group.add_argument( + "--use-wenetspeech", + type=str2bool, + default=False, + ) + + group.add_argument( + "--wenetspeech-subset", + type=str, + default="L", + ) + + group.add_argument( + "--repeat-wenetspeech", + type=int, + default=1, + ) + + group.add_argument( + "--use-mls", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-aishell", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-chinese-dataset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--use-extra-english-dataset", + type=str2bool, + default=False, + ) + + # KD related + group.add_argument( + "--mvq-KD", + type=str2bool, + default=False, + help="If load the codebook indexes instead of ground truth of audio events" + ) + + group.add_argument( + "--at-KD", + type=str2bool, + default=False, + help="If load the logits instead of ground truth of audio events" + ) + + group.add_argument( + "--sv-KD", + type=str2bool, + default=False, + help="If load speaker embedding instead of speaker identity" + ) + + # multi task dataset related + group.add_argument( + "--use-voxceleb", + type=str2bool, + default=False, + help="If use voxceleb as training set. This will not affet the model params.", + ) + + group.add_argument( + "--voxceleb-subset", + type=str, + default="vox1", + choices=["vox1", "vox2", "only_vox2"], + help="Which subset of voxceleb to use. If vox2, then vox1 and vox2 will be used.", + ) + + group.add_argument( + "--use-audioset", + type=str2bool, + default=False, + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "full"] + ) + + group.add_argument( + "--at-weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "This only applies to audio tagging", + ) + + group.add_argument( + "--at-num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler in AudioSet dataset", + ) + + group.add_argument( + "--repeat-audioset", + type=int, + default=1, + ) + + def train_dataloaders( + self, + cuts_train: Union[CutSet, Dict[str, CutSet]], + sampler_state_dict: Optional[Dict[str, Any]] = None, + sampling_weight: List[int] = None, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + # properly set world_size and rank + if self.args.use_shar: + logging.info(f"Setting world_size=1 and rank=0 because we will be using shar!") + world_size = 1 + rank = 0 + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz").drop_features() + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + num_frame_masks = int(10 * self.args.time_mask_ratio) + max_frames_mask_fraction = 0.15 * self.args.time_mask_ratio + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=self.args.features_mask_size, + num_feature_masks=2, + frames_mask_size=self.args.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, + ) + ) + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}, " + f"frames_mask_size: {self.args.frames_mask_size}, " + f"features_mask_size: {self.args.features_mask_size}" + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = MultiTaskDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = MultiTaskDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + assert self.args.zip_sampler == False, "Cannot use ZipSampler when using Dynamic Bucketing sampler" + assert isinstance(cuts_train, CutSet), "DynamicBucketSampler only supports one training cuts" + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + elif self.args.zip_sampler: + logging.info(f"Using ZipSampler to combine multiple samplers") + assert len(cuts_train) > 1, "Can't use ZipSampler when only having one CutSet" + # By default, we use DynamicBucket sampler for non-audio-tagging dataset + # and if at_weighted_sampler=True, we use weighted sampler for audio tagging data + # By using the ZipSampler, we can alleviate the problem of unbalanced batching when + # using datasoures consisting of MULTIPLE tasks of very different durations (we only sample + # from a single bucket each time, and this bucket could be highly dominated by one task) + # However, this requires more careful setting of the max-duration for each sampler + # and the distribution of cuts in each batch is more difficult to control + assert isinstance(cuts_train, Dict), "ZipSampler requires multiple training cuts/samplers" + + samplers = [] + + for i, (name, cuts) in enumerate(cuts_train.items()): + # NOTE: The sampling weight should reflects the total duration of + # each cutset, as they will be higher likely to be exhausted at the same + # time + md = self.args.max_duration * sampling_weight[i]/ sum(sampling_weight) + logging.info(f"max duration for {name}: {md}") + if "audioset" not in name: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + sampler = WeightedSimpleCutSampler( + cuts, + weights, + num_samples=self.args.at_num_samples, + max_duration=md, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + sampler = DynamicBucketingSampler( + cuts, + max_duration=md, + shuffle=self.args.shuffle, + num_buckets=5, + buffer_size=10000, + shuffle_buffer_size=10000, + drop_last=self.args.drop_last, + ) + + samplers.append(sampler) + + train_sampler = ZipSampler( + *samplers, + merge_batches=True, + ) + else: + assert len(cuts_train) == 1, f"The training cuts contain {len(cuts_train)} cutsets" + cuts_train = list(cuts_train.values())[0] + if self.args.at_weighted_sampler: + weights = self.audioset_sampling_weights() + train_sampler = WeightedSimpleCutSampler( + cuts_train, + weights, + num_samples=self.args.at_num_samples, + max_duration=self.args.max_duration, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + if not self.args.use_shar: + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + else: + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + logging.info("Wrapping the dataset and sampler to an iterable") + + logging.info(f"World size: {train_sampler.world_size}") + logging.info(f"Rank: {train_sampler.rank}") + + rank = train_sampler.rank + world_size = train_sampler.world_size + + train_sampler.world_size = 1 + train_sampler.rank = 0 + + train_iter_dataset = IterableDatasetWrapper( + dataset=train, + sampler=train_sampler, + ) + + train_dl = DataLoader( + train_iter_dataset, + batch_size=None, + num_workers=self.args.num_workers, + worker_init_fn=make_worker_init_fn(seed=0, rank=rank, world_size=world_size), + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = MultiTaskDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + else: + validate = MultiTaskDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + world_size: int = None, + rank: int = None, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = MultiTaskDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + mvq_KD=self.args.mvq_KD, + at_KD=self.args.at_KD, + sv_KD=self.args.sv_KD + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/train-all-shuf", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + if self.args.use_shar: + logging.info(f"Use share for librispeech dev-clean cuts") + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-clean", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + if self.args.use_shar: + return CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/librispeech/dev-other", + shuffle_shards=False, + ) + else: + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_train_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech training cuts") + gigaspeech_list = ["xs", "s", "m", "l", "xl"] + durations = [10, 240, 750, 1500, 7500] + assert self.args.gigaspeech_subset in gigaspeech_list, self.args.gigaspeech_subset + + all_cuts = CutSet() + all_cuts = [] + weights = [] + for i, subset in enumerate(gigaspeech_list): + logging.info(f"Loading gigaspeech cuts subset: {subset}") + weights.append(durations[i]) + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/{subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy(self.args.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz") + all_cuts.append(cuts) + if self.args.gigaspeech_subset == subset: + break + all_cuts = CutSet.mux( + *all_cuts, + weights=weights, + stop_early=False, + ) + + return all_cuts + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/gigaspeech/dev", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_dev.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_test.jsonl.gz") + + @lru_cache() + def libriheavy_train_cuts(self) -> CutSet: + logging.info(f"About to get libriheavy {self.args.libriheavy_subset} subset cuts") + if self.args.use_shar: + medium_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/libriheavy/medium", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + if self.args.libriheavy_subset == "medium": + return medium_cuts + else: + assert self.args.libriheavy_subset == "large" + large_cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/libriheavy/large", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = [medium_cuts, large_cuts] + return CutSet.mux( + *cuts, + weights=[1, 9], + stop_early=False, + ) + + else: + return load_manifest_lazy( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.libriheavy_subset}.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_train_cuts(self) -> CutSet: + logging.info(f"About to get wenetspeech {self.args.wenetspeech_subset} cuts") + if self.args.use_shar: + num_splits = 10 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/wenetspeech/L/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.wenetspeech_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def wenetspeech_valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/wenetspeech/DEV", + shuffle_shards=False, + ) + return cuts + else: + return load_manifest_lazy( + self.args.manifest_dir / "wenetspeech_cuts_DEV.jsonl.gz" + ) + + @lru_cache() + def wenetspeech_test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "wenetspeech_cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def aishell_train_cuts(self) -> CutSet: + logging.info("About to get aishell training cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_train.jsonl.gz") + + @lru_cache() + def aishell_dev_cuts(self) -> CutSet: + logging.info("About to get aishell dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") + + @lru_cache() + def aishell_test_cuts(self) -> CutSet: + logging.info("About to get aishell test cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") + + @lru_cache() + def mls_cuts(self) -> CutSet: + logging.info("About to get MLS cuts") + if self.args.use_shar: + num_splits = 8 + all_cuts = [] + for i in range(num_splits): + split_dir = f"{str(self.args.shar_dir)}/MLS/split_{i}" + logging.info(f"Loading {split_dir}") + cuts = CutSet.from_shar( + in_dir=split_dir, + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + cuts = cuts.resample(16000) + all_cuts.append(cuts) + return CutSet.mux( + *all_cuts, + weights=[1.0] * num_splits, + stop_early=False, + ).resample(16000) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"wenetspeech_cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def multi_english_cuts(self): + logging.info("About to get various English dataset cuts") + datasets = ["peoplespeech", "common_voice_20200622"] + datasets += ["en_us_english", "en8848", "ljspeech", "tatoeba", "ted", "vctk", "voase", "voaSplider"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of English speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def multi_chinese_cuts(self): + logging.info("About to get various Chinese dataset cuts") + datasets = ["accent", "aidatatang_200zh", "aishell3", "aishell2","baidu_en_cn","common_voice_20200622","datatang1505"] + datasets += ["dialog3k", "magicdata", "sensetime", "ximalaya", "acq", "cantonese", "cs_wav", "dialog"] + datasets += ["MagicData_dialog","primewords_md_2018_set1","zhvoice","phone","speech_wav"] + datasets += ["digital_library_202003", "ST-CMDS-20170001_1-OS", "20220309"] + all_cuts = [] + cuts_duration = [] + cuts_len = [] + for dataset in datasets: + logging.info(f"Loading {dataset}") + cuts = CutSet.from_shar( + in_dir=f"{self.args.shar_dir}/{dataset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + all_cuts.append(cuts) + cuts_duration.append(self.dataset_duration_stats[dataset]) + cuts_len.append(self.dataset_len_stats[dataset]) + + # alimeeting_cuts, ali_dur, ali_num_cuts = self.alimeeting_cuts() + # all_cuts.append(alimeeting_cuts) + # cuts_duration.append(ali_dur) + # cuts_len.append(ali_num_cuts) + + all_cuts = CutSet.mux( + *all_cuts, + weights=cuts_duration, + stop_early=False + ) + all_cuts = all_cuts.resample(16000) + all_duration = sum(cuts_duration) + all_len = sum(cuts_len) + # logging.info(f"Combining {datasets}") + logging.info(f"Getting a total of {all_duration} hours ({all_len} samples) of Chinese speech data. ") + return all_cuts, all_duration, all_len + + @lru_cache() + def alimeeting_cuts(self): + # alimeeting: 140 hrs, 186364 cuts + def reduce_supervisions(c): + supervisions = c.supervisions + supervisions = [supervisions[0]] + c.supervisions = supervisions + return c + logging.info("About to get the alimeeting cuts") + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/alimeeting/train", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "alimeeting-far_cuts_train.jsonl.gz" + ) + cuts = cuts.map(reduce_supervisions) + + return cuts.drop_features(), 140, 186364 + + @cached_property + def dataset_duration_stats(self): + stats_file = f"{self.args.shar_dir}/stats_duration.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = float(data[1]) + return stats + + @cached_property + def dataset_len_stats(self): + stats_file = f"{self.args.shar_dir}/stats_len.txt" + stats = {} + with open(stats_file, "r") as f: + for line in f: + data = line.strip().split() + stats[data[0]] = int(data[1]) + return stats + + @lru_cache() + def audioset_cuts(self) -> CutSet: + logging.info("About to get the audioset cuts.") + if self.args.audioset_subset == "full": + if not self.args.at_weighted_sampler: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/full", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + from lhotse import load_manifest + cuts = load_manifest( + self.args.manifest_dir / "audioset_cuts_full.jsonl.gz" + ) + else: + if self.args.use_shar: + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/{self.args.audioset_subset}", + shuffle_shards=True, + stateful_shuffle=True, + seed="randomized", + ).repeat() + else: + cuts = load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_balanced.jsonl.gz" + ) + return cuts.drop_features() + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + if self.args.use_shar: + logging.info(f"Use share for audioset eval cuts") + cuts = CutSet.from_shar( + in_dir=f"{str(self.args.shar_dir)}/audioset/eval", + shuffle_shards=False, + ) + return cuts + return load_manifest_lazy( + self.args.manifest_dir / "audioset_cuts_eval.jsonl.gz" + ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sampling_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights + + @lru_cache() + def voxceleb_cuts(self) -> CutSet: + # this should be used in KD + logging.info("About to get the voxceleb cuts.") + if self.args.voxceleb_subset == "only_vox2": + logging.info("Only get the voxceleb2 cuts.") + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_vox1_train.jsonl.gz" + ) + if self.args.voxceleb_subset == "vox2": + logging.info("Adding voxceleb2 cuts.") + cuts += load_manifest_lazy( + self.args.manifest_dir / "cuts_vox2_train.jsonl.gz" + ) + return cuts + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + MultiTaskDataModule.add_arguments(parser) + + args = parser.parse_args() + + mtl_datamodule = MultiTaskDataModule(args) + + from functools import partial + from utils import _add_dummy_embeddings_and_taskIDs + from lhotse import CutSet + cuts_path = "cuts.json" + cuts = CutSet.from_json(cuts_path) + asr_cuts = cuts.repeat(200) + asr_cuts = asr_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + cuts[0].id = cuts[0].id + "_at" + at_cuts = cuts.repeat(2000) + at_cuts = at_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) # ASR task ID=0 + at_cuts = at_cuts.to_eager() + sampling_weight = [300,100] + + train_cuts = { + "asr_cuts": asr_cuts, + "audio_tagging_cuts": at_cuts, + } + + train_dl = mtl_datamodule.train_dataloaders( + cuts_train=train_cuts, + sampling_weight=sampling_weight + ) + num_epochs = 3 + for epoch in range(1, num_epochs+1): + train_dl.sampler.set_epoch(epoch-1) + num1, num2 = 0, 0 + for batch_idx, batch in enumerate(train_dl): + task_ids = batch["task_ids"] + num1 += sum(task_ids == 1) + num2 += sum(task_ids == 2) + print(f"Epoch {epoch}, batch {batch_idx}: {sum(task_ids == 1)} {sum(task_ids == 2)}") + cuts = batch["supervisions"]["cut"] + if batch_idx == 0: + print([c.id for c in cuts]) + assert num2 <= args.at_num_samples + print(f"Number of cuts from task1: {num1}") + print(f"Number of cuts from task2: {num2}") + \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/optim.py b/egs/emilia/CLAP/spear/optim.py new file mode 100644 index 0000000000..9eff5c22d5 --- /dev/null +++ b/egs/emilia/CLAP/spear/optim.py @@ -0,0 +1,1184 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for (stacked_params, _state, _names), batch in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. + # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + assert "named_params" in cur_group + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for p, state, param_names in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + if tot_sumsq == 0.0: # for freezing parameters + return 1.0 + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + if median == 0.0: + # after freezing for a certain number of steps, start to optimize the + # parameters, they will have no accumulated stats, so the medium will be + # zero, this won't affect other parameters. + median = quartiles[-1] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 20.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + if ans != ans: # e.g. ans is nan + ans = 0.0 + if ans == 0.0: + for p, state, param_names in tuples: + p.grad.zero_() # get rid of infinity() + + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for p, state, batch_param_names in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummy values used by following `zip` statement. + batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + atol=1e-2, + ), sum([value[0] for value in all_sumsq_orig.values()]).cpu() + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad *= clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/emilia/CLAP/spear/scaling.py b/egs/emilia/CLAP/spear/scaling.py new file mode 120000 index 0000000000..6f398f431d --- /dev/null +++ b/egs/emilia/CLAP/spear/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/scaling_bf16.py b/egs/emilia/CLAP/spear/scaling_bf16.py new file mode 100644 index 0000000000..c307657e79 --- /dev/null +++ b/egs/emilia/CLAP/spear/scaling_bf16.py @@ -0,0 +1,1913 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# 2024 University of Cambridge (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +import random +from typing import Optional, Tuple, Union + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + for x, y in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise linear + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p cross. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specify the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) + + def __float__(self): + batch_count = self.batch_count + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. + """ + + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = x > self.cutoff + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1 - q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return x * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). + @staticmethod + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interpreted as an offset from the input's ndim if negative. + This is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() + return x * scales + + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) + + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) + + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. + self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: + """Forward function. + + Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + # half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Streaming Forward function. + + Args: + x: a Tensor of shape (batch_size, channels, seq_len) + cache: cached left context of shape (batch_size, channels, left_pad) + """ + (batch_size, num_channels, seq_len) = x.shape + + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + # Pad cache + assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -left_pad:] + + x_causal = self.causal_conv(x) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size=seq_len) + x_chunk = x_chunk * chunk_scale + + return x_chunk + x_causal, cache + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad = x_grad + (x_grad.abs() * loss_grad) + except Exception as e: + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. + return 0.8139535143 * _atanh(x) + + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + dtype = x_orig.dtype + x_detached = x_orig.detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = float(w.grad_scale) * ( + x_grad.to(dtype).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert float(grad_scale) >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.044 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. + @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + (ans,) = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + # return k2.swoosh_l_forward(x) + return SwooshLForward(x) + else: + # return k2.swoosh_l(x) + return SwooshLFunction.apply(x) # this support bf16 + + +class SwooshLOnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + if not x.requires_grad: + # return k2.swoosh_r_forward(x) + return SwooshRForward(x) + else: + # return k2.swoosh_r(x) + return SwooshRFunction.apply(x) + + +class SwooshROnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. + activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter("bias", l.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + self.dropout = Dropout3(dropout_p, shared_dim=dropout_shared_dim) + + def forward(self, x: Tensor): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + if self.activation == "SwooshL": + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, self.weight, self.bias) + + if self.activation == "SwooshL": + x = SwooshL()(x) + elif self.activation == "SwooshR": + x = SwooshR()(x) + + x = self.dropout(x) + return torch.nn.functional.linear(x, self.weight, self.bias) + + + # return ActivationDropoutAndLinearFunction.apply( + # x, + # self.weight, + # self.bias, + # self.activation, + # float(self.dropout_p), + # self.dropout_shared_dim, + # ) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear((0, 10.0)) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_double_swish_deriv() + _test_swooshr_deriv() + _test_swooshl_deriv() + _test_activation_dropout_and_linear() diff --git a/egs/emilia/CLAP/spear/subsampling.py b/egs/emilia/CLAP/spear/subsampling.py new file mode 120000 index 0000000000..01ae9002c6 --- /dev/null +++ b/egs/emilia/CLAP/spear/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/train_at_KD.py b/egs/emilia/CLAP/spear/train_at_KD.py new file mode 100644 index 0000000000..41638287f0 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_at_KD.py @@ -0,0 +1,1517 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = None + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + _, audio_tagging_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # AT loss + mask = task_ids == 2 # AT=2 + assert torch.all(mask) + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = ["/".join(str(c.shard_origin).split("/")[1:3]) for c in cuts] + unique_origin = set(shard_origin) + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + # if sh in shard_count: + # shard_count[sh] += 1 + # else: + # shard_count[sh] = 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + # logging.info(f"Batch {batch_idx}: Cuts stats: {shard_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + rank = setup_distributed() + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # audio data + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + assert params.bucketing_sampler, "Only support bucketing sampler in AT KD" + + train_dl = librispeech.train_dataloaders( + audioset_cuts, + sampler_state_dict=sampler_state_dict, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3.py b/egs/emilia/CLAP/spear/train_multi_KD3.py new file mode 100644 index 0000000000..15908dca34 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3.py @@ -0,0 +1,1618 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3 import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_dummy_embeddings_and_taskIDs, MetricsTracker + +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.02 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + # audio tagging label + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + loss += mvq_loss + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.8, # 122512 + "medium": 1050000 * 0.8, # 1093040 + } + libriheavy_cuts_duration = { + "small": 500 * 0.8, + "medium": 4154 * 0.8, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + # combine the asr data into a BIG cut + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=True, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + if params.use_audioset and params.do_audio_tagging: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 21.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + # TODO: add the wenetspeech valid cuts + pass + + if params.use_audioset and params.do_audio_tagging: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar.py new file mode 100644 index 0000000000..312f3c068f --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar.py @@ -0,0 +1,1760 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--mvq-loss-by-task", + type=str2bool, + default=True, + help="If True, only compute MVQ loss on the task from which the sample is drawn." + "Otherwise, ignore the task_ids and treat all data as if they come from the same task" + ) + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + else: + mvq_loss = mvq_loss.sum() + loss += mvq_loss + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + def estimate_cur_epoch(max_duration: float, world_size: int, steps: int, train_hrs: int): + estimated_hours = max_duration * world_size * steps / 3600 + estimated_epochs = estimated_hours // train_hrs + return estimated_epochs + + shard_count = {} + cur_epoch = 0 + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if params.use_shar: + est_epoch = estimate_cur_epoch( + params.max_duration, world_size, params.batch_idx_train, params.train_duration + ) + + if est_epoch > cur_epoch: + cur_epoch = est_epoch + scheduler.step_epoch(cur_epoch) # start from 1 + logging.info(f"Estimated epoch: {cur_epoch}") + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + unique_origin = set(shard_origin) + for ori in shard_origin: + if ori in shard_count: + shard_count[ori] += 1 + else: + shard_count[ori] = 1 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + # logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + logging.info(f"All shards source by far: {shard_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 466, + "medium": 4148, + "large": 42074, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + # assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + params.train_duration = sum(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 29.9: + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_co_training.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_co_training.py new file mode 100644 index 0000000000..9b7e65d153 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_co_training.py @@ -0,0 +1,1726 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.dataset import SpecAugment +from lhotse.utils import fix_random_seed +from model_multi_kd_co_training import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--use-co-training", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--co-training-loss-scale", + type=float, + default=0.2, + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--mvq-loss-by-task", + type=str2bool, + default=True, + help="If True, only compute MVQ loss on the task from which the sample is drawn." + "Otherwise, ignore the task_ids and treat all data as if they come from the same task" + ) + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> MultiKDModel: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}, " + f"features_mask_size: {params.features_mask_size}, " + f"frames_mask_size: {params.frames_mask_size}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # We don't do time warp for MVQ pre-training + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=params.features_mask_size, + num_feature_masks=2, + frames_mask_size=params.frames_mask_size, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + specaug: Optional[SpecAugment] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + # potentially perform specaug for co-training + use_co_training = params.use_co_training and is_training + use_spec_aug = use_co_training and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + + if use_co_training: + task_ids = task_ids.repeat(2) + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss, co_training_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + use_co_training=use_co_training, + use_spec_aug=use_spec_aug, + spec_augment=specaug, + supervision_segments=supervision_segments, + ) + + loss = 0.0 + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + else: + mvq_loss = mvq_loss.sum() + loss += mvq_loss + + # co-training loss is computed on all data + if use_co_training: + co_training_loss = co_training_loss.sum() + loss += co_training_loss * params.co_training_loss_scale + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + if use_co_training: + info["co_training_loss"] = co_training_loss.detach().cpu().item() + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + specaug: Optional[SpecAugment] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + unique_origin = set(shard_origin) + for ori in shard_origin: + if ori in shard_count: + shard_count[ori] += 1 + else: + shard_count[ori] = 1 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + logging.info(f"All shards source by far: {shard_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + specaug=specaug, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + rank = setup_distributed() + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # Create a standalone specaugment module, we won't do + # specaug in the dataloading process + if params.use_co_training: + assert params.enable_spec_aug == False, "When performing co-training, we apply specaugment inside the forward function of the model" + specaug = get_spec_augment(params) + logging.info("We will perform spec augment for co-training") + else: + specaug = None + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 466, + "medium": 4148, + "large": 42074, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + # combine the asr data into a BIG cut + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + if params.use_audioset and params.do_audio_tagging: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 29.9: + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + specaug=specaug, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_inter_cb_w2v2_mask.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_inter_cb_w2v2_mask.py new file mode 100644 index 0000000000..a60a88c681 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_inter_cb_w2v2_mask.py @@ -0,0 +1,2001 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut, MonoCut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_w2v2_mask_inter_cb import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # mvq related + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--intermediate-cb", + type=str2bool, + default=True, + help="If True, compute the codebook loss for the intermediate layer" + ) + + parser.add_argument( + "--intermediate-block-idx", + type=int, + default=3, + help="The block index of the intermediate codebook loss." + ) + + parser.add_argument( + "--intermediate-cb-loss-scale", + type=float, + default=0.1, + help="The scale for the intermediate codebook loss." + ) + + parser.add_argument( + "--mvq-loss-by-task", + type=str2bool, + default=True, + help="If True, only compute MVQ loss on the task from which the sample is drawn." + "Otherwise, ignore the task_ids and treat all data as if they come from the same task" + ) + + # masking related + parser.add_argument( + "--loss-only-mask", + type=str2bool, + default=False, + help="If True, only compute loss on the masked indices" + ) + + parser.add_argument( + "--mask-mode", + type=str, + default="w2v2", + choices=["w2v2", "block"], + help="The masking mode", + ) + + parser.add_argument( + "--mask-length", type=int, default=10, help="mask_length" + ) + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--mask-channel-length", type=int, default=15, help="mask_length" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a channel with mask", + ) + + # normalization + parser.add_argument( + "--normalize-fbank", + type=str2bool, + default=False, + help="If perform normalization to the input fbank features" + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-hours", + type=float, + default=20000, + help="""Number of hours trained speech that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + parser.add_argument( + "--estimate-epoch", + type=str2bool, + default=True, + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + params.subsampling_factor = 2 + assert params.enable_spec_aug == False, "Should not use specaug when using w2v2 style masking" + if params.loss_only_mask: + logging.info("Only computing loss on the masked positions") + if params.normalize_fbank: + logging.info("Normalizing the input fbank features") + + if params.intermediate_cb: + assert params.intermediate_block_idx >= 0, "intermediate_block_idx should be non-negative" + assert params.intermediate_block_idx < len(_to_int_tuple(params.num_encoder_layers)) + assert params.intermediate_block_idx != len(_to_int_tuple(params.num_encoder_layers)) -1, "Don't use the last block for intermediate codebook loss" + logging.info(f"Using intermediate codebook loss at block {params.intermediate_block_idx}") + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + n_mels=params.feature_dim, + mask_mode=params.mask_mode, + mask_prob=params.mask_prob, + mask_length=params.mask_length, + mask_selection=params.mask_selection, + mask_other=params.mask_other, + mask_channel_prob=params.mask_channel_prob, + mask_channel_length=params.mask_channel_length, + loss_only_mask=params.loss_only_mask, + normalize_fbank=params.normalize_fbank, + intermediate_cb=params.intermediate_cb, + intermediate_block_idx=params.intermediate_block_idx, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + cuts = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss, inter_mvq_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + else: + mvq_loss = mvq_loss.sum() + loss += mvq_loss + + if params.intermediate_cb: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + inter_mvq_loss = (inter_mvq_loss * mask).sum() + else: + inter_mvq_loss = inter_mvq_loss.sum() + loss += params.intermediate_cb_loss_scale * inter_mvq_loss + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.intermediate_cb: + info["inter_mvq_loss"] = inter_mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + def estimate_cur_epoch(max_duration: float, world_size: int, steps: int, train_hrs: int): + estimated_hours = max_duration * world_size * steps / 3600 + estimated_epochs = estimated_hours // train_hrs + return estimated_epochs + + shard_count = {} + shard_durations_count = {} + cur_epoch = 0 + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if params.use_shar and params.estimate_epoch: + est_epoch = estimate_cur_epoch( + params.max_duration, world_size, params.batch_idx_train, params.train_duration, + ) + if est_epoch > cur_epoch: + cur_epoch = est_epoch + # scheduler.step_epoch(cur_epoch) # start from 1 + logging.info(f"Estimated epoch: {cur_epoch}") + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + cuts = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + durations = [c.duration for c in cuts] + unique_origin = set(shard_origin) + for ori, dur in zip(shard_origin, durations): + if ori in shard_count: + shard_count[ori] += 1 + shard_durations_count[ori] += dur / 3600 + else: + shard_count[ori] = 1 + shard_durations_count[ori] = dur / 3600 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 100 == 1: + logging.info(f"All shards source by far: {shard_count}") + logging.info(f"All shard duration by far: {shard_durations_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + # Use the number of hours of speech to adjust the learning rate + scheduler.step_epoch( + params.batch_idx_train * params.max_duration * params.world_size / 3600 + ) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + # the max iteration criteria should be applied to both shar and non-shar + if params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + num_param_prediction_head = sum([p.numel() for p in model.codebook_loss_net.parameters()]) + logging.info(f"Number of encoder parameters: {num_param - num_param_prediction_head}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_hours, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 118334, + "medium": 1062926, + "large": 10796160, + } + libriheavy_cuts_duration = { + "small": 473, + "medium": 4208 + 473, + "large": 42683 + 4208 + 473, # 47364 hrs + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_voxpopuli: + voxpopuli_cuts = librispeech.voxpopuli_unlabelled_cuts() + voxpopuli_cuts = voxpopuli_cuts.map(partial(_add_task_id, 1)) + # vox en unlabelled: 24151 hrs, 3059813 cuts + asr_training_cuts.append(voxpopuli_cuts) + asr_training_cuts_lens.append(3059813) + asr_training_cuts_duration.append(24151) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_emotion_dataset: + other_emotion_cuts = librispeech.multi_emotion_cuts() + msp_podcast_cuts = librispeech.msp_podcast_train_cust() + emotion_cuts = CutSet.mux( + *[other_emotion_cuts, msp_podcast_cuts], + weights=[134, 52], + stop_early=False, + ) + emotion_cuts = emotion_cuts.map(partial(_add_task_id, 1)) # for now we treat ER cuts as part of ASR cuts + asr_training_cuts.append(emotion_cuts) + asr_training_cuts_lens.append(130297 * params.repeat_emo) # 46267 + 84030 + asr_training_cuts_duration.append(186 * params.repeat_emo) # 52 + 134 + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + # assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + logging.info(f"ASR cuts: {asr_training_cuts}") + logging.info(f"ASR cuts length: {asr_training_cuts_lens}") + logging.info(f"ASR cuts duration: {asr_training_cuts_duration}") + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + audio_training_cuts = [] + audio_training_cuts_lens = [] + audio_training_cuts_duration = [] + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + audio_training_cuts.append(audioset_cuts) + audio_training_cuts_lens.append(num_audio_cuts) + audio_training_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + if params.use_music4all: + music4all_cuts = librispeech.music4all_cuts() # 910 hrs, 109269 cuts, 30s + music4all_cuts = music4all_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(music4all_cuts) + audio_training_cuts_lens.append(109269 * params.repeat_music4all) + audio_training_cuts_duration.append(910 * params.repeat_music4all) + + if params.use_vggsound: + vggsound_cuts = librispeech.vggsound_train_cuts() # 427 hrs, 154142 cuts + vggsound_cuts = vggsound_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(vggsound_cuts) + audio_training_cuts_lens.append(154142 * params.repeat_vggsound) + audio_training_cuts_duration.append(427 * params.repeat_vggsound) + + if params.use_bbceffect: + # split into 10s + bbceffect_cuts = librispeech.bbc_soundeffect_train_cuts() # 430 hrs, 160905 cuts + bbceffect_cuts = bbceffect_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(bbceffect_cuts) + audio_training_cuts_lens.append(160905) + audio_training_cuts_duration.append(430) + + if params.use_freesound: + # split into 10s + freesound_cuts = librispeech.freesound_train_cuts() # 2516 hrs, 1073093 cuts + freesound_cuts = freesound_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(freesound_cuts) + audio_training_cuts_lens.append(1073093) + audio_training_cuts_duration.append(2516) + + if params.use_mtg: + # split into 10s + mtg_cuts = librispeech.mtg_cuts() # + mtg_cuts = mtg_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(mtg_cuts) + audio_training_cuts_lens.append(1032727) + audio_training_cuts_duration.append(2812) + + # combine the audio datasets + if len(audio_training_cuts) >= 1: + logging.info(f"audio cuts: {audio_training_cuts}") + logging.info(f"audio cuts length: {audio_training_cuts_lens}") + logging.info(f"audio cuts duration: {audio_training_cuts_duration}") + if len(audio_training_cuts) > 1: + audio_training_cuts = CutSet.mux( + *audio_training_cuts, + weights=audio_training_cuts_lens, + stop_early=False, + ) + else: + audio_training_cuts = audio_training_cuts[0] + + train_cuts["cuts_audio"] = audio_training_cuts + train_cuts_duration.append(sum(audio_training_cuts_duration)) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + params.train_duration = sum(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 31: + return False + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_emotion_dataset: + msp_podcast_dev_cuts = librispeech.msp_podcast_dev_cust() + msp_podcast_dev_cuts = msp_podcast_dev_cuts.map(partial(_add_task_id, 1)) + er_msp_dev_dl = librispeech.valid_dataloaders(msp_podcast_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ER_msp_podcast") + valid_dls.append(er_msp_dev_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + if params.use_vggsound: + vggsound_eval_cuts = librispeech.vggsound_test_cuts() + vggsound_eval_cuts = vggsound_eval_cuts.map(partial(_add_task_id, 2)) + vggsound_valid_dl = librispeech.valid_dataloaders(vggsound_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_vggsound") + valid_dls.append(vggsound_valid_dl) + + if params.use_bbceffect: + bbc_test_cuts = librispeech.bbc_soundeffect_test_cuts() + bbc_test_cuts = bbc_test_cuts.map(partial(_add_task_id, 2)) + bbc_test_dl = librispeech.valid_dataloaders(bbc_test_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_bbc") + valid_dls.append(bbc_test_dl) + + # if params.use_freesound: + # freesound_test_cuts = librispeech.freesound_test_cuts() + # freesound_test_cuts = freesound_test_cuts.map(partial(_add_task_id, 2)) + # freesound_test_dl = librispeech.valid_dataloaders(freesound_test_cuts, world_size=world_size, rank=rank,) + # valid_sets.append("AT_freesound") + # valid_dls.append(freesound_test_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + # scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_joint_loss.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_joint_loss.py new file mode 100644 index 0000000000..4c72192c52 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_joint_loss.py @@ -0,0 +1,1754 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--mvq-loss-by-task", + type=str2bool, + default=True, + help="If True, only compute MVQ loss on the task from which the sample is drawn." + "Otherwise, ignore the task_ids and treat all data as if they come from the same task" + ) + + parser.add_argument( + "--audio-sample-loss-scale", + type=float, + default=1.0, + help="Scale down the loss computed from the audio data" + ) + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + else: + loss_mask = torch.ones(len(task_ids)).to(device) + loss_mask[task_ids == 2] = params.audio_sample_loss_scale + mvq_loss = (mvq_loss * loss_mask).sum() + + loss += mvq_loss + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + unique_origin = set(shard_origin) + for ori in shard_origin: + if ori in shard_count: + shard_count[ori] += 1 + else: + shard_count[ori] = 1 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + logging.info(f"All shards source by far: {shard_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 466, + "medium": 4148, + "large": 42074, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + # assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 29.9: + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_mae.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_mae.py new file mode 100644 index 0000000000..a68e096b4d --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_mae.py @@ -0,0 +1,1865 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.distributed as dist +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_mae import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import ( + _add_task_id, + MetricsTracker, + setup_distributed, +) + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + # encoder related + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + # decoder related arguments + parser.add_argument( + "--num-decoder-layers", + type=str, + default="2,2,2,2,2,2", + help="Number of zipformer decoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--decoder-dim", + type=str, + default="256,256,256,256,256,256", + ) + + parser.add_argument( + "--decoder-downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--decoder-feedforward-dim", + type=str, + default="768,768,768,768,768,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--decoder-unmasked-dim", + type=str, + default="192,192,192,192,192,192", + help="Unmasked dimensions in the decoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--decoder-num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=False, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--mae-downsample-factor", + type=int, + default=4, + help="""The final downsample factor after the decoder. This includes both + the downsampling at encoder and decoder. + """ + ) + + parser.add_argument( + "--mae-loss-norm", + type=str, + default="sample", + choices=["sample", "batch", "frame"] + ) + + parser.add_argument( + "--mae-loss-scale", + type=float, + default=0.1, + help="The loss scale for MAE loss" + ) + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + # TODO: make this applicable to more than two losses + parser.add_argument( + "--speech-mvq-loss-scale", + type=float, + default=1.0, + help="The scale of whisper mvq losses" + ) + + parser.add_argument( + "--audio-mvq-loss-scale", + type=float, + default=1.0, + help="The scale of dasheng mvq losses" + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--save-with-client", + type=str2bool, + default=False, + help="If True, save the model to s3 client" + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + + num_params = sum([p.numel() for p in encoder.parameters()]) + logging.info(f"Number of parameters in the encoder: {num_params}") + return encoder + +def get_decoder(params: AttributeDict) -> nn.Module: + decoder = Zipformer2( + output_downsampling_factor=1, # this is fixed, we don't want further downsample + num_encoder_layers=_to_int_tuple(params.num_decoder_layers), + downsampling_factor=_to_int_tuple(params.decoder_downsampling_factor), + encoder_dim=_to_int_tuple(params.decoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.decoder_unmasked_dim), + feedforward_dim=_to_int_tuple(params.decoder_feedforward_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.decoder_num_heads), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + causal=False, + ) + + num_params = sum([p.numel() for p in decoder.parameters()]) + logging.info(f"Number of parameters in the decoder: {num_params}") + + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + assert params.mae_downsample_factor == params.output_downsampling_factor * 2 + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + decoder_dim=max(_to_int_tuple(params.decoder_dim)), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + n_mels=params.feature_dim, + mae_loss_norm=params.mae_loss_norm, + mae_downsample_factor=params.mae_downsample_factor, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss, mae_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + fbank_target=feature, + ) + + loss = 0.0 + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss, first is whisper MVQ, second is Dasheng MVQ + if params.do_mvq: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + loss += mvq_loss + + mae_loss = (mae_loss * mask).sum() + loss += mae_loss * params.mae_loss_scale + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + info["mae_loss"] = mae_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = ["/".join(str(c.shard_origin).split("/")[1:3]) for c in cuts] + unique_origin = set(shard_origin) + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + # save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + rank = setup_distributed() + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 466, + "medium": 4148, + "large": 42074, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + else: + asr_training_cuts = CutSet() + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # general audio data + if params.do_audio_tagging: + assert params.use_audioset, "If we do audio tagging, we must use audioset" + + def change_codebook_indexes(c): + c.audio_codebook_indexes = c.codebook_indexes + del c.codebook_indexes + return c + + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + audioset_cuts = audioset_cuts.map(change_codebook_indexes) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 29.0: + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + as_eval_cuts = as_eval_cuts.map(change_codebook_indexes) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_multi_mvq.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_multi_mvq.py new file mode 100644 index 0000000000..ab0484a867 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_multi_mvq.py @@ -0,0 +1,1767 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar_multi_teacher import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_multi_teacher import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=str, + default="-1,-1", + ) + + parser.add_argument( + "--distillation-delta", + type=str, + default="0,0", + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=str, + default="2,2", + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=str, + default="16,16", + ) + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + # TODO: make this applicable to more than two losses + parser.add_argument( + "--disjoint", + type=str2bool, + default=False, + help="""Default False, which means compute multi-mvq losses on all data. + Otherwise only compute loss on ASR data""" + ) + + parser.add_argument( + "--mvq-loss-scales", + type=str, + default="1.0,1.0", + help="The scale of two mvq losses" + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=_to_int_tuple(params.num_codebooks), + distillation_layer=_to_int_tuple(params.distillation_layer), + distillation_delta=_to_int_tuple(params.distillation_delta), + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=_to_int_tuple(params.teacher_frame_ratio), + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"] + mvq_tokens = [tokens.to(device) for tokens in mvq_tokens] + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + mvq_losses = losses[:-1] + audio_tagging_loss = losses[-1] + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + mvq_loss_scales = tuple(map(float, params.mvq_loss_scales.split(","))) + mvq_loss_values = [] + if params.do_mvq: + # mask = task_ids == 1 # ASR=1 + import pdb; pdb.set_trace() + if not params.disjoint: + mask = task_ids != 0 # compute loss for all data + else: + mask = task_ids == 1 # ASR=1 + for mvq_loss, scale in zip(mvq_losses, mvq_loss_scales): + mvq_loss = (mvq_loss * mask).sum() + mvq_loss_values.append(mvq_loss) + loss += mvq_loss * scale # TODO: make this an option + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + teachers = ["whisper", "dasheng"] + for i, mvq_loss in enumerate(mvq_loss_values): + info[f"{teachers[i]}_mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = ["/".join(str(c.shard_origin).split("/")[1:3]) for c in cuts] + unique_origin = set(shard_origin) + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + logging.info(count) + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 500 * 0.9, + "medium": 3687, + "large": 37218, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + else: + asr_training_cuts = CutSet() + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) # sum of [] is 0 + + # general audio data + if params.do_audio_tagging: + assert params.use_audioset, "If we do audio tagging, we must use audioset" + + def change_codebook_indexes(c): + c.codebook_indexes2 = c.codebook_indexes + del c.codebook_indexes + return c + + # audio data + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + audioset_cuts = audioset_cuts.map(change_codebook_indexes) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 29.0: + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_multi_mvq2.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_multi_mvq2.py new file mode 100644 index 0000000000..66f5fec149 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_multi_mvq2.py @@ -0,0 +1,1892 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import io +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar_multi_teacher2 import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_multi_teacher import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +import torch.distributed as dist + +from utils import _add_task_id, _add_language_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=str, + default="-1,-1", + ) + + parser.add_argument( + "--distillation-delta", + type=str, + default="0,0", + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=str, + default="2,2", + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=str, + default="16,16", + ) + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + # TODO: make this applicable to more than two losses + parser.add_argument( + "--mvq-loss-scales", + type=str, + default="0.5,0.5", + help="The scale of two mvq losses" + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-with-client", + type=str2bool, + default=True, + help="If True, save the model to s3 client" + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 4000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=_to_int_tuple(params.num_codebooks), + distillation_layer=_to_int_tuple(params.distillation_layer), + distillation_delta=_to_int_tuple(params.distillation_delta), + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=_to_int_tuple(params.teacher_frame_ratio), + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + is_zh = torch.tensor([c.language_id == "zh" for c in cuts]).to(device) + is_en = torch.tensor([c.language_id == "en" for c in cuts]).to(device) + language_masks = [is_en, is_zh] + + if is_training: + if params.batch_idx_train % 100 == 0: + logging.info(f"Step: {params.batch_idx_train}, zh data: {is_zh.sum()}, en data: {is_en.sum()}") + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"] + mvq_tokens = [tokens.to(device) for tokens in mvq_tokens] + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + mvq_losses = losses[:-1] # (whisper_mvq, firered_mvq) + audio_tagging_loss = losses[-1] + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + mvq_loss_scales = tuple(map(float, params.mvq_loss_scales.split(","))) + mvq_loss_values = [] + if params.do_mvq: + mask = task_ids == 1 # ASR=1 + for mvq_loss, scale, language_mask in zip(mvq_losses, mvq_loss_scales, language_masks): + mvq_loss = (mvq_loss * mask * language_mask).sum() + mvq_loss_values.append(mvq_loss) + loss += mvq_loss * scale # TODO: make this an option + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + teachers = ["whisper", "firered"] + for i, mvq_loss in enumerate(mvq_loss_values): + info[f"{teachers[i]}_mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=0, + ) + + shard_count = {} + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + + for ori in shard_origin: + if ori in shard_count: + shard_count[ori] += 1 + else: + shard_count[ori] = 1 + # unique_origin = set(shard_origin) + # count = {orig: 0 for orig in unique_origin} + # for sh in shard_origin: + # count[sh] += 1 + + if batch_idx % 100 == 1: + shard_epoch = [int(c.shar_epoch) for c in cuts] + max_epoch = max(shard_epoch) + logging.info(f"Estimated epoch is {max_epoch}") + # logging.info(count) + logging.info(f"All shards source by far: {shard_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # this is a modified version of saving checkpoint + _save_checkpoint_with_global_batch_idx( + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + if world_size > 1: + dist.barrier() + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + Note that this is the global rank. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() # this will setup the device + else: + local_rank = 0 + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + if world_size > 1: + logging.info(f"Global Rank: {dist.get_rank()}, Local Rank: {local_rank}, CUDA Device Count: {torch.cuda.device_count()}") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) # this should be local rank + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + logging.info(f"Inialisting the model avg") + model_avg = copy.deepcopy(model).to(torch.float64) + + # save the model to client + if rank == 0 and params.save_with_client: + from petrel_client.client import Client + conf_path = "/mnt/petrelfs/share_data/housiyuan/petreloss.conf" + client = Client(conf_path) + params.client = client + else: + params.client = None + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + # NOTE: since the chinese data are usually much shorter than english data + # using len for mux sampling might be not appropriate, we decided to combine + # english and chinese cuts separately, and then mux with a fixed weight + + en_asr_training_cuts = [] + en_asr_training_cuts_lens = [] + en_asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 28539 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + en_asr_training_cuts.append(librispeech_cuts) + en_asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + en_asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + en_asr_training_cuts.append(gigaspeech_cuts) + en_asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + en_asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 500 * 0.9, + "medium": 3687, + "large": 37218, + } + en_asr_training_cuts.append(libriheavy_cuts) + en_asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + en_asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + # mls cuts: 10801 hrs, 2619190 cuts + en_asr_training_cuts.append(mls_cuts) + en_asr_training_cuts_lens.append(2619190) + en_asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + en_asr_training_cuts.append(englishs_cuts) + en_asr_training_cuts_lens.append(english_cuts_len) + en_asr_training_cuts_duration.append(english_cut_durations) + + # combine the English asr data into a BIG cut + assert len(en_asr_training_cuts) >= 1, len(en_asr_training_cuts) + if len(en_asr_training_cuts) > 1: + en_asr_training_cuts = CutSet.mux( + *en_asr_training_cuts, + weights=en_asr_training_cuts_lens, + stop_early=False, + ) + else: + en_asr_training_cuts = en_asr_training_cuts[0] + + en_asr_training_cuts = en_asr_training_cuts.map(partial(_add_language_id, "en")) + logging.info(f"Total English data: {sum(en_asr_training_cuts_duration)} hours, {sum(en_asr_training_cuts_lens)} cuts.") + + # from now on, Chinese cuts + def change_codebook_indexes(c): + if c.has_custom("firered_codebook_indexes"): # if the cut already has firered cb, do nothing + del c.codebook_indexes + return c + else: + assert c.has_custom("codebook_indexes") + c.firered_codebook_indexes = c.codebook_indexes + del c.codebook_indexes + return c + + zh_asr_training_cuts = [] + zh_asr_training_cuts_lens = [] + zh_asr_training_cuts_duration = [] + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(change_codebook_indexes) + zh_asr_training_cuts.append(wenetspeech_cuts) + zh_asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + zh_asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_weread: + weread_cuts, weread_cut_durations, weread_cuts_len = librispeech.weread_dataset_cuts() + weread_cuts = weread_cuts.map(change_codebook_indexes) + zh_asr_training_cuts.append(weread_cuts) + zh_asr_training_cuts_lens.append(weread_cuts_len) + zh_asr_training_cuts_duration.append(weread_cut_durations) + + if params.use_extra_chinese_dataset: + chinese_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chinese_cuts = chinese_cuts.map(change_codebook_indexes) + zh_asr_training_cuts.append(chinese_cuts) + zh_asr_training_cuts_lens.append(chinese_cuts_len) + zh_asr_training_cuts_duration.append(chinese_cut_durations) + + # combine the chinese asr data into a BIG cut + assert len(zh_asr_training_cuts) >= 1, len(zh_asr_training_cuts) + if len(zh_asr_training_cuts) > 1: + zh_asr_training_cuts = CutSet.mux( + *zh_asr_training_cuts, + weights=zh_asr_training_cuts_lens, + stop_early=False, + ) + else: + zh_asr_training_cuts = zh_asr_training_cuts[0] + logging.info( + f"Total Chinese data: {sum(zh_asr_training_cuts_duration)} hours,\ + {sum(zh_asr_training_cuts_lens)} cuts." + ) + + zh_asr_training_cuts = zh_asr_training_cuts.map(partial(_add_language_id, "zh")) + + # combine the en and zh ASR data + asr_training_cuts = [en_asr_training_cuts, zh_asr_training_cuts] + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=[1.0,1.0], + stop_early=False + ) + asr_training_cuts = asr_training_cuts.map(partial(_add_task_id, 1)) + + asr_training_cuts_duration = sum(en_asr_training_cuts_duration) + sum(zh_asr_training_cuts_duration) + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(asr_training_cuts_duration) + + # audio data + if params.use_audioset and params.do_audio_tagging: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + def change_to_s3(c): + source = c.recording.sources[0].source + source = source.replace("download/", "brainllm:s3://yangxiaoyu/") + c.recording.sources[0].source = source + c.recording.sources[0].type = "url" + return c + + audioset_cuts = audioset_cuts.map(change_to_s3) + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + audioset_cuts = audioset_cuts.map(partial(_add_language_id, "none")) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + train_cuts["cuts_audioset"] = audioset_cuts + train_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 29.0: + return False + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + logging.info(f"World size: {world_size}, rank: {rank}") + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + ls_valid_cuts = ls_valid_cuts.map(partial(_add_language_id, "en")) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + giga_dev_cuts = giga_dev_cuts.map(partial(_add_language_id, "en")) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_language_id, "zh")) + wenet_dev_cuts = wenet_dev_cuts.map(change_codebook_indexes) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_audioset and params.do_audio_tagging: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(change_to_s3) + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + as_eval_cuts = as_eval_cuts.map(partial(_add_language_id, "none")) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + +def _save_checkpoint_with_global_batch_idx( + params, + model, + optimizer = None, + sampler = None, + scheduler = None, + scaler = None, + model_avg = None, + rank: int = 0, +): + # only active when rank==0 + if rank != 0: + return + + if isinstance(model, DDP): + model = model.module + else: + model = model + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "scheduler": scheduler.state_dict() if scheduler is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + output_path = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" + + if params.save_with_client: + logging.info(f"Saving checkpoint to {output_path}") + with io.BytesIO() as f: + output_path = "brainllm:s3://yangxiaoyu/" + str(output_path) + torch.save(checkpoint, f) + f.seek(0) + params.client.put(output_path, f) + logging.info(f"Finish saving checkpoint to {output_path}") + else: + logging.info(f"Saving checkpoint to {output_path}") + torch.save(checkpoint, output_path) + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_speech_audio_multi_mvq.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_speech_audio_multi_mvq.py new file mode 100644 index 0000000000..5117d6120d --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_speech_audio_multi_mvq.py @@ -0,0 +1,2025 @@ + #!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar_speech_audio_multi_teacher import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_multi_teacher import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=str, + default="-1,-1", + ) + + parser.add_argument( + "--distillation-delta", + type=str, + default="0,0", + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=str, + default="2,2", + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=str, + default="16,16", + ) + + # masking related + parser.add_argument( + "--loss-only-mask", + type=str2bool, + default=False, + help="If True, only compute loss on the masked indices" + ) + + parser.add_argument( + "--mask-mode", + type=str, + default="w2v2", + choices=["w2v2", "block"], + help="The masking mode", + ) + + parser.add_argument( + "--mask-length", type=int, default=10, help="mask_length" + ) + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--mask-channel-length", type=int, default=15, help="mask_length" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a channel with mask", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + # TODO: make this applicable to more than two losses + parser.add_argument( + "--speech-mvq-loss-scale", + type=float, + default=1.0, + help="The scale of speech mvq losses" + ) + + parser.add_argument( + "--audio-mvq-loss-scale", + type=float, + default=1.0, + help="The scale of audio mvq losses" + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + + assert params.enable_spec_aug == False, "Should not use specaug when using w2v2 style masking" + if params.loss_only_mask: + logging.info("Only computing loss on the masked positions") + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=_to_int_tuple(params.num_codebooks), + distillation_layer=_to_int_tuple(params.distillation_layer), + distillation_delta=_to_int_tuple(params.distillation_delta), + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=_to_int_tuple(params.teacher_frame_ratio), + n_mels=params.feature_dim, + mask_mode=params.mask_mode, + mask_prob=params.mask_prob, + mask_length=params.mask_length, + mask_selection=params.mask_selection, + mask_other=params.mask_other, + mask_channel_prob=params.mask_channel_prob, + mask_channel_length=params.mask_channel_length, + loss_only_mask=params.loss_only_mask, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + cut_ids = [c.id for c in cuts] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"] + mvq_tokens = [tokens.to(device) for tokens in mvq_tokens] + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + speech_mvq_loss, audio_mvq_loss = losses[:-1] + audio_tagging_loss = losses[-1] + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss, first is whisper MVQ, second is Dasheng MVQ + mvq_loss_values = [] + if params.do_mvq: + speech_mask = task_ids == 1 # ASR data task_id=1 + num_speech_frames = feature_lens[speech_mask].sum() // 4 # equivalent frames + if torch.isnan(speech_mvq_loss).any(): # filter the nan loss + logging.info(f"Detected NaN in speech mvq loss") + speech_mvq_loss = torch.nan_to_num(speech_mvq_loss, nan=0.0) + speech_mvq_loss = (speech_mvq_loss * speech_mask).sum() + mvq_loss_values.append(speech_mvq_loss) + # loss += speech_mvq_loss/ (num_speech_frames + 1) * params.speech_mvq_loss_scale # TODO: make this an option + loss += speech_mvq_loss * params.speech_mvq_loss_scale # TODO: make this an option + + audio_mask = task_ids == 2 + num_audio_frames = feature_lens[audio_mask].sum() // 4 # equivalent frames + + # if num_audio_frames == 0: + # correction_factor = 0.0 + # elif num_speech_frames == 0: + # correction_factor = 1.0 + # else: + # correction_factor = num_speech_frames / num_audio_frames + # if random.random() < 0.02: + # logging.info(f"Correction factor: {correction_factor}") + correction_factor = 1.0 + if torch.isnan(audio_mvq_loss).any(): + logging.info(f"Detected NaN in audio mvq loss") + audio_mvq_loss = torch.nan_to_num(audio_mvq_loss, nan=0.0) + audio_mvq_loss = (audio_mvq_loss * audio_mask).sum() + mvq_loss_values.append(audio_mvq_loss) # the un-normalized loss + # loss += (audio_mvq_loss / (num_audio_frames + 1)) * params.audio_mvq_loss_scale # TODO: make this an option + loss += audio_mvq_loss * correction_factor * params.audio_mvq_loss_scale # TODO: make this an option + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker(normalize=True) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + teachers = ["speech", "audio"] + for i, mvq_loss in enumerate(mvq_loss_values): + info[f"{teachers[i]}_mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + # logging.info(f"Batch: {params.batch_idx_train}: speech mvq loss: {speech_mvq_loss}, num_frames: {num_speech_frames}") + # logging.info(f"Batch: {params.batch_idx_train}: audio mvq loss: {audio_mvq_loss}, num_frames: {num_audio_frames}") + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker(normalize=True) + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=0, + ) + + def estimate_cur_epoch(max_duration: float, world_size: int, steps: int, train_hrs: int): + estimated_hours = max_duration * world_size * steps / 3600 + estimated_epochs = estimated_hours // train_hrs + return estimated_epochs + + shard_count = {} + cur_epoch = 0 + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + if params.use_shar: + est_epoch = estimate_cur_epoch( + params.max_duration, world_size, params.batch_idx_train, params.train_duration, + ) + if est_epoch > cur_epoch: + cur_epoch = est_epoch + scheduler.step_epoch(cur_epoch) # start from 1 + logging.info(f"Estimated epoch: {cur_epoch}") + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + unique_origin = set(shard_origin) + for ori in shard_origin: + if ori in shard_count: + shard_count[ori] += 1 + else: + shard_count[ori] = 1 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 2 == 1: + task_ids = batch["task_ids"] + num_speech_cuts = sum(task_ids == 1).item() + speech_duration = sum([c.duration for c in cuts if c.task_id == 1]) + num_audio_cuts = sum(task_ids == 2).item() + audio_duration = sum([c.duration for c in cuts if c.task_id == 2]) + logging.info(f"batch {batch_idx}: task cuts: {num_speech_cuts}, {num_audio_cuts}, task durations: {speech_duration}, {audio_duration}") + # logging.info(count) + logging.info(f"All shards source by far: {shard_count}") + continue + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + # save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + # for valid_set, valid_dl in zip(valid_sets, valid_dls): + # logging.info("Computing validation loss") + # valid_info = compute_validation_loss( + # params=params, + # model=model, + # sp=sp, + # valid_dl=valid_dl, + # world_size=world_size, + # ) + + # logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + # logging.info( + # f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + # ) + # if tb_writer is not None: + # valid_info.write_summary( + # tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + # ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs, avg dur 4.2s + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset] * params.repeat_gigaspeech) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset] * params.repeat_gigaspeech) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 466, + "medium": 4148, + "large": 42074, # avg dur: 15s + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_mls: + mls_cuts = librispeech.mls_train_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 6000 hrs, 1409826 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(1409826) + asr_training_cuts_duration.append(6000) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_emotion_dataset: + other_emotion_cuts = librispeech.multi_emotion_cuts() + msp_podcast_cuts = librispeech.msp_podcast_train_cust() + emotion_cuts = CutSet.mux( + *[other_emotion_cuts, msp_podcast_cuts], + weights=[134, 52], + stop_early=False, + ) + emotion_cuts = emotion_cuts.map(partial(_add_task_id, 1)) # for now we treat ER cuts as part of ASR cuts + asr_training_cuts.append(emotion_cuts) + asr_training_cuts_lens.append(130297 * params.repeat_emo) # 46267 + 84030 + asr_training_cuts_duration.append(186 * params.repeat_emo) # 52 + 134 + + if params.use_fisher: + fisher_cuts = librispeech.fisher_cuts() + fisher_cuts = fisher_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 2041 hrs, 2113438 cuts + asr_training_cuts.append(fisher_cuts) + asr_training_cuts_lens.append(2113438) + asr_training_cuts_duration.append(2041) + + if params.use_voxpopuli: + # multi-lingual data + if params.voxpopuli_subset == "en_v2": + voxpopuli_cuts = librispeech.voxpopuli_unlabelled_cuts() + asr_training_cuts_lens.append(3059813) + asr_training_cuts_duration.append(24151) # avg dur: 28.4 + else: + voxpopuli_cuts = librispeech.voxpopuli_asr_train_cuts() + asr_training_cuts_lens.append(526497) + asr_training_cuts_duration.append(1636) + + voxpopuli_cuts = voxpopuli_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(voxpopuli_cuts) + + # combine the asr data into a BIG cut + assert len(asr_training_cuts) >= 1 + if len(asr_training_cuts) >= 1: + logging.info(f"ASR cuts: {asr_training_cuts}") + logging.info(f"ASR cuts length: {asr_training_cuts_lens}") + logging.info(f"ASR cuts duration: {asr_training_cuts_duration}") + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # general audio data + if params.do_audio_tagging: + assert params.use_audioset, "If we do audio tagging, we must use audioset" + + def change_codebook_indexes(c): + c.audio_codebook_indexes = c.codebook_indexes + del c.codebook_indexes + return c + + # audio data + audio_training_cuts = [] + audio_training_cuts_lens = [] + audio_training_cuts_duration = [] + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + audioset_cuts = audioset_cuts.map(change_codebook_indexes) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + audio_training_cuts.append(audioset_cuts) + audio_training_cuts_lens.append(num_audio_cuts) + audio_training_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + if params.use_music4all: + # all 30s cuts + music4all_cuts = librispeech.music4all_cuts() # 910 hrs, 109269 cuts + music4all_cuts = music4all_cuts.map(partial(_add_task_id, 2)) + music4all_cuts = music4all_cuts.map(change_codebook_indexes) + audio_training_cuts.append(music4all_cuts) + audio_training_cuts_lens.append(109269 * params.repeat_music4all) + audio_training_cuts_duration.append(910 * params.repeat_music4all) + + if params.use_vggsound: + # all 10s cuts + vggsound_cuts = librispeech.vggsound_train_cuts() # 427 hrs, 154142 cuts + vggsound_cuts = vggsound_cuts.map(partial(_add_task_id, 2)) + vggsound_cuts = vggsound_cuts.map(change_codebook_indexes) + audio_training_cuts.append(vggsound_cuts) + audio_training_cuts_lens.append(154142 * params.repeat_vggsound) + audio_training_cuts_duration.append(427 * params.repeat_vggsound) + + if params.use_bbceffect: + # split into 10s + bbceffect_cuts = librispeech.bbc_soundeffect_train_cuts() # 430 hrs, 160905 cuts + bbceffect_cuts = bbceffect_cuts.map(partial(_add_task_id, 2)) + bbceffect_cuts = bbceffect_cuts.map(change_codebook_indexes) + audio_training_cuts.append(bbceffect_cuts) + audio_training_cuts_lens.append(160905) + audio_training_cuts_duration.append(430) + + if params.use_freesound: + # split into 10s, so all cuts <=10s + freesound_cuts = librispeech.freesound_train_cuts() # 2811 hrs, 1028645 cuts + freesound_cuts = freesound_cuts.map(partial(_add_task_id, 2)) + freesound_cuts = freesound_cuts.map(change_codebook_indexes) + audio_training_cuts.append(freesound_cuts) + audio_training_cuts_lens.append(1073093) + audio_training_cuts_duration.append(2516) + + if params.use_mtg: + # split into 10s + mtg_cuts = librispeech.mtg_cuts() # + mtg_cuts = mtg_cuts.map(partial(_add_task_id, 2)) + mtg_cuts = mtg_cuts.map(change_codebook_indexes) + audio_training_cuts.append(mtg_cuts) + audio_training_cuts_lens.append(1032727) + audio_training_cuts_duration.append(2812) + + # combine the audio datasets + assert len(audio_training_cuts) >= 1 + if len(audio_training_cuts) >= 1: + logging.info(f"audio cuts: {audio_training_cuts}") + logging.info(f"audio cuts length: {audio_training_cuts_lens}") + logging.info(f"audio cuts duration: {audio_training_cuts_duration}") + if len(audio_training_cuts) > 1: + audio_training_cuts = CutSet.mux( + *audio_training_cuts, + weights=audio_training_cuts_lens, + stop_early=False, + ) + else: + audio_training_cuts = audio_training_cuts[0] + + train_cuts["cuts_audio"] = audio_training_cuts + train_cuts_duration.append(sum(audio_training_cuts_duration)) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + params.train_duration = sum(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + # because we have some music cuts, the duration is 30 second + if c.duration < 0.9 or c.duration > 31.0: + return False + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + assert len(train_cuts) == 2, "We should only have speech and audio cuts" + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=[2.0, 1.0], + # weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_emotion_dataset: + msp_podcast_dev_cuts = librispeech.msp_podcast_dev_cust() + msp_podcast_dev_cuts = msp_podcast_dev_cuts.map(partial(_add_task_id, 1)) + er_msp_dev_dl = librispeech.valid_dataloaders(msp_podcast_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ER_msp_podcast") + valid_dls.append(er_msp_dev_dl) + + if params.use_voxpopuli and params.voxpopuli_subset != "en_v2": + voxpopuli_dev_cuts = librispeech.voxpopuli_dev_cuts() + voxpopuli_dev_cuts = voxpopuli_dev_cuts.map(partial(_add_task_id, 1)) + asr_voxpopuli_dev_dl = librispeech.valid_dataloaders(voxpopuli_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_voxpopuli") + valid_dls.append(asr_voxpopuli_dev_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + as_eval_cuts = as_eval_cuts.map(change_codebook_indexes) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + if params.use_vggsound: + vggsound_eval_cuts = librispeech.vggsound_test_cuts() + vggsound_eval_cuts = vggsound_eval_cuts.map(partial(_add_task_id, 2)) + vggsound_eval_cuts = vggsound_eval_cuts.map(change_codebook_indexes) + vggsound_valid_dl = librispeech.valid_dataloaders(vggsound_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_vggsound") + valid_dls.append(vggsound_valid_dl) + + if params.use_bbceffect: + bbc_test_cuts = librispeech.bbc_soundeffect_test_cuts() + bbc_test_cuts = bbc_test_cuts.map(partial(_add_task_id, 2)) + bbc_test_cuts = bbc_test_cuts.map(change_codebook_indexes) + bbc_test_dl = librispeech.valid_dataloaders(bbc_test_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_bbc") + valid_dls.append(bbc_test_dl) + + # if params.use_freesound: + # freesound_test_cuts = librispeech.freesound_test_cuts() + # freesound_test_cuts = freesound_test_cuts.map(partial(_add_task_id, 2)) + # freesound_test_cuts = freesound_test_cuts.map(change_codebook_indexes) + # freesound_test_dl = librispeech.valid_dataloaders(freesound_test_cuts, world_size=world_size, rank=rank,) + # valid_sets.append("AT_freesound") + # valid_dls.append(freesound_test_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_w2v2_mask.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_w2v2_mask.py new file mode 100644 index 0000000000..8150bc26f9 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_w2v2_mask.py @@ -0,0 +1,1962 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut, MonoCut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_w2v2_mask import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # mvq related + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--mvq-loss-by-task", + type=str2bool, + default=True, + help="If True, only compute MVQ loss on the task from which the sample is drawn." + "Otherwise, ignore the task_ids and treat all data as if they come from the same task" + ) + + # masking related + parser.add_argument( + "--loss-only-mask", + type=str2bool, + default=False, + help="If True, only compute loss on the masked indices" + ) + + parser.add_argument( + "--mask-mode", + type=str, + default="w2v2", + choices=["w2v2", "block"], + help="The masking mode", + ) + + parser.add_argument( + "--mask-length", type=int, default=10, help="mask_length" + ) + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--mask-channel-length", type=int, default=15, help="mask_length" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a channel with mask", + ) + + # normalization + parser.add_argument( + "--normalize-fbank", + type=str2bool, + default=False, + help="If perform normalization to the input fbank features" + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-hours", + type=float, + default=20000, + help="""Number of hours trained speech that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + parser.add_argument( + "--estimate-epoch", + type=str2bool, + default=True, + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + params.subsampling_factor = 2 + assert params.enable_spec_aug == False, "Should not use specaug when using w2v2 style masking" + if params.loss_only_mask: + logging.info("Only computing loss on the masked positions") + if params.normalize_fbank: + logging.info("Normalizing the input fbank features") + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + n_mels=params.feature_dim, + mask_mode=params.mask_mode, + mask_prob=params.mask_prob, + mask_length=params.mask_length, + mask_selection=params.mask_selection, + mask_other=params.mask_other, + mask_channel_prob=params.mask_channel_prob, + mask_channel_length=params.mask_channel_length, + loss_only_mask=params.loss_only_mask, + normalize_fbank=params.normalize_fbank, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + cuts = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + else: + mvq_loss = mvq_loss.sum() + loss += mvq_loss + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + def estimate_cur_epoch(max_duration: float, world_size: int, steps: int, train_hrs: int): + estimated_hours = max_duration * world_size * steps / 3600 + estimated_epochs = estimated_hours // train_hrs + return estimated_epochs + + shard_count = {} + shard_durations_count = {} + cur_epoch = 0 + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if params.use_shar and params.estimate_epoch: + est_epoch = estimate_cur_epoch( + params.max_duration, world_size, params.batch_idx_train, params.train_duration, + ) + if est_epoch > cur_epoch: + cur_epoch = est_epoch + # scheduler.step_epoch(cur_epoch) # start from 1 + logging.info(f"Estimated epoch: {cur_epoch}") + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + cuts = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + durations = [c.duration for c in cuts] + unique_origin = set(shard_origin) + for ori, dur in zip(shard_origin, durations): + if ori in shard_count: + shard_count[ori] += 1 + shard_durations_count[ori] += dur / 3600 + else: + shard_count[ori] = 1 + shard_durations_count[ori] = dur / 3600 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 100 == 1: + logging.info(f"All shards source by far: {shard_count}") + logging.info(f"All shard duration by far: {shard_durations_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + # Use the number of hours of speech to adjust the learning rate + scheduler.step_epoch( + params.batch_idx_train * params.max_duration * params.world_size / 3600 + ) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + # the max iteration criteria should be applied to both shar and non-shar + if params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + num_param_prediction_head = sum([p.numel() for p in model.codebook_loss_net.parameters()]) + logging.info(f"Number of encoder parameters: {num_param - num_param_prediction_head}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_hours, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 118334, + "medium": 1062926, + "large": 10796160, + } + libriheavy_cuts_duration = { + "small": 473, + "medium": 4208 + 473, + "large": 42683 + 4208 + 473, # 47364 hrs + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_voxpopuli: + voxpopuli_cuts = librispeech.voxpopuli_unlabelled_cuts() + voxpopuli_cuts = voxpopuli_cuts.map(partial(_add_task_id, 1)) + # vox en unlabelled: 24151 hrs, 3059813 cuts + asr_training_cuts.append(voxpopuli_cuts) + asr_training_cuts_lens.append(3059813) + asr_training_cuts_duration.append(24151) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_emotion_dataset: + other_emotion_cuts = librispeech.multi_emotion_cuts() + msp_podcast_cuts = librispeech.msp_podcast_train_cust() + emotion_cuts = CutSet.mux( + *[other_emotion_cuts, msp_podcast_cuts], + weights=[134, 52], + stop_early=False, + ) + emotion_cuts = emotion_cuts.map(partial(_add_task_id, 1)) # for now we treat ER cuts as part of ASR cuts + asr_training_cuts.append(emotion_cuts) + asr_training_cuts_lens.append(130297 * params.repeat_emo) # 46267 + 84030 + asr_training_cuts_duration.append(186 * params.repeat_emo) # 52 + 134 + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + # assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + logging.info(f"ASR cuts: {asr_training_cuts}") + logging.info(f"ASR cuts length: {asr_training_cuts_lens}") + logging.info(f"ASR cuts duration: {asr_training_cuts_duration}") + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + audio_training_cuts = [] + audio_training_cuts_lens = [] + audio_training_cuts_duration = [] + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + audio_training_cuts.append(audioset_cuts) + audio_training_cuts_lens.append(num_audio_cuts) + audio_training_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + if params.use_music4all: + music4all_cuts = librispeech.music4all_cuts() # 910 hrs, 109269 cuts, 30s + music4all_cuts = music4all_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(music4all_cuts) + audio_training_cuts_lens.append(109269 * params.repeat_music4all) + audio_training_cuts_duration.append(910 * params.repeat_music4all) + + if params.use_vggsound: + vggsound_cuts = librispeech.vggsound_train_cuts() # 427 hrs, 154142 cuts + vggsound_cuts = vggsound_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(vggsound_cuts) + audio_training_cuts_lens.append(154142 * params.repeat_vggsound) + audio_training_cuts_duration.append(427 * params.repeat_vggsound) + + if params.use_bbceffect: + # split into 10s + bbceffect_cuts = librispeech.bbc_soundeffect_train_cuts() # 430 hrs, 160905 cuts + bbceffect_cuts = bbceffect_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(bbceffect_cuts) + audio_training_cuts_lens.append(160905) + audio_training_cuts_duration.append(430) + + if params.use_freesound: + # split into 10s + freesound_cuts = librispeech.freesound_train_cuts() # 2516 hrs, 1073093 cuts + freesound_cuts = freesound_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(freesound_cuts) + audio_training_cuts_lens.append(1073093) + audio_training_cuts_duration.append(2516) + + if params.use_mtg: + # split into 10s + mtg_cuts = librispeech.mtg_cuts() # + mtg_cuts = mtg_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(mtg_cuts) + audio_training_cuts_lens.append(1032727) + audio_training_cuts_duration.append(2812) + + # combine the audio datasets + if len(audio_training_cuts) >= 1: + logging.info(f"audio cuts: {audio_training_cuts}") + logging.info(f"audio cuts length: {audio_training_cuts_lens}") + logging.info(f"audio cuts duration: {audio_training_cuts_duration}") + if len(audio_training_cuts) > 1: + audio_training_cuts = CutSet.mux( + *audio_training_cuts, + weights=audio_training_cuts_lens, + stop_early=False, + ) + else: + audio_training_cuts = audio_training_cuts[0] + + train_cuts["cuts_audio"] = audio_training_cuts + train_cuts_duration.append(sum(audio_training_cuts_duration)) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + params.train_duration = sum(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 31: + return False + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_emotion_dataset: + msp_podcast_dev_cuts = librispeech.msp_podcast_dev_cust() + msp_podcast_dev_cuts = msp_podcast_dev_cuts.map(partial(_add_task_id, 1)) + er_msp_dev_dl = librispeech.valid_dataloaders(msp_podcast_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ER_msp_podcast") + valid_dls.append(er_msp_dev_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + if params.use_vggsound: + vggsound_eval_cuts = librispeech.vggsound_test_cuts() + vggsound_eval_cuts = vggsound_eval_cuts.map(partial(_add_task_id, 2)) + vggsound_valid_dl = librispeech.valid_dataloaders(vggsound_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_vggsound") + valid_dls.append(vggsound_valid_dl) + + if params.use_bbceffect: + bbc_test_cuts = librispeech.bbc_soundeffect_test_cuts() + bbc_test_cuts = bbc_test_cuts.map(partial(_add_task_id, 2)) + bbc_test_dl = librispeech.valid_dataloaders(bbc_test_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_bbc") + valid_dls.append(bbc_test_dl) + + # if params.use_freesound: + # freesound_test_cuts = librispeech.freesound_test_cuts() + # freesound_test_cuts = freesound_test_cuts.map(partial(_add_task_id, 2)) + # freesound_test_dl = librispeech.valid_dataloaders(freesound_test_cuts, world_size=world_size, rank=rank,) + # valid_sets.append("AT_freesound") + # valid_dls.append(freesound_test_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + # scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_multi_KD3_shar_w2v2_mask_token_mixing.py b/egs/emilia/CLAP/spear/train_multi_KD3_shar_w2v2_mask_token_mixing.py new file mode 100644 index 0000000000..9fc38bd774 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_multi_KD3_shar_w2v2_mask_token_mixing.py @@ -0,0 +1,1934 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 University of Cambridge (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +from functools import partial +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from kd_datamodule3_shar_token_mixing import MultiTaskDataModule +from lhotse import CutSet +from lhotse.cut import Cut, MonoCut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_multi_kd_w2v2_mask import MultiKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from utils import _add_task_id, MetricsTracker, setup_distributed + +from zipformer2 import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=False, + help="Whether to fine-tune.", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--freeze-modules", + type=str, + default=None, + help=""" + Modules to be frozen. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # mvq related + parser.add_argument( + "--do-mvq", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--do-audio-tagging", + type=str2bool, + default=True, + help="If do audio tagging multi task training" + ) + + parser.add_argument( + "--do-speaker-verification", + type=str2bool, + default=False, + help="If do speaker verification" + ) + + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--interpolate-teacher", + type=str2bool, + default=False, + help="""This should only be used when the teacher has a lower frame rate + than the student model. We use interpolation to find the nearest neighbour""" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + parser.add_argument( + "--mvq-loss-by-task", + type=str2bool, + default=True, + help="If True, only compute MVQ loss on the task from which the sample is drawn." + "Otherwise, ignore the task_ids and treat all data as if they come from the same task" + ) + + # masking related + parser.add_argument( + "--loss-only-mask", + type=str2bool, + default=False, + help="If True, only compute loss on the masked indices" + ) + + parser.add_argument( + "--mask-mode", + type=str, + default="w2v2", + choices=["w2v2", "block"], + help="The masking mode", + ) + + parser.add_argument( + "--mask-length", type=int, default=10, help="mask_length" + ) + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--mask-channel-length", type=int, default=15, help="mask_length" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a channel with mask", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--max-iters", + type=int, + default=200000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="multi_task/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0 + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--audio-tagging-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--speaker-verification-loss-scale", + type=float, + default=1.0, + help="Scale for audio tagging loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + help="If stop early if using mux" + ) + + parser.add_argument( + "--estimate-epoch", + type=str2bool, + default=True, + ) + + add_finetune_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, # for better audio capability + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for multitask + "num_tasks": 2, + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.interpolate_teacher: + logging.warning(f"Interpolate the teacher indexes to match the length of the student") + assert params.teacher_frame_ratio == 1 + + if params.output_downsampling_factor == 1: + logging.info(f"Setting the output downsample factor to 1.") + if params.teacher_frame_ratio > 1: + logging.warning( + f"You are using teacher_frame_ratio={params.teacher_frame_ratio}. " + "However, the output downsampling factor is 1. This could be wrong!" + ) + assert params.enable_spec_aug == False, "Should not use specaug when using w2v2 style masking" + if params.loss_only_mask: + logging.info("Only computing loss on the masked positions") + + model = MultiKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + interpolate_teacher=params.interpolate_teacher, + teacher_frame_ratio=params.teacher_frame_ratio, + n_mels=params.feature_dim, + mask_mode=params.mask_mode, + mask_prob=params.mask_prob, + mask_length=params.mask_length, + mask_selection=params.mask_selection, + mask_other=params.mask_other, + mask_channel_prob=params.mask_channel_prob, + mask_channel_length=params.mask_channel_length, + loss_only_mask=params.loss_only_mask, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + assert params.start_epoch == 1 + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + logging.info(f"Loading {key} from init ckpt") + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + feature_lens = supervisions["num_frames"].to(device) + task_ids = batch["task_ids"].int().to(device) + + if random.random() < 0.01 and is_training: + cuts = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + for t in range(1, params.num_tasks+1): + duration = sum([c.duration for c in cuts if c.task_id == t]) + logging.info(f"Number of samples from task {t}: {sum(task_ids == t).item()}/{len(task_ids)}") + logging.info(f"Total duration of task {t}: {duration}") + + # mvq tokens + mvq_tokens = batch["cb_indexes"].to(device) + + # audio tagging label + if params.do_audio_tagging: + at_targets = batch["at_targets"].to(device) # the label indices are in CED format + else: + at_targets = None + + with torch.set_grad_enabled(is_training): + mvq_loss, audio_tagging_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + at_targets=at_targets, + ) + + loss = 0.0 + + # task_id=1: ASR data + # task_id=2: AT data + + # MVQ loss + if params.do_mvq: + if params.mvq_loss_by_task: + mask = task_ids == 1 # ASR=1 + mvq_loss = (mvq_loss * mask).sum() + else: + mvq_loss = mvq_loss.sum() + loss += mvq_loss + + # AT loss + if params.do_audio_tagging: + mask = task_ids == 2 # AT=2 + audio_tagging_loss = (audio_tagging_loss.sum(dim=-1) * mask).sum() # this also works if mask is all False + loss += params.audio_tagging_loss_scale * audio_tagging_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["utterances"] = task_ids.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.do_mvq: + info["mvq_loss"] = mvq_loss.detach().cpu().item() + if params.do_audio_tagging: + info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + def estimate_cur_epoch(max_duration: float, world_size: int, steps: int, train_hrs: int): + estimated_hours = max_duration * world_size * steps / 3600 + estimated_epochs = estimated_hours // train_hrs + return estimated_epochs + + shard_count = {} + cur_epoch = 0 + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if params.use_shar and params.estimate_epoch: + est_epoch = estimate_cur_epoch( + params.max_duration, world_size, params.batch_idx_train, params.train_duration, + ) + if est_epoch > cur_epoch: + cur_epoch = est_epoch + scheduler.step_epoch(cur_epoch) # start from 1 + logging.info(f"Estimated epoch: {cur_epoch}") + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + if params.use_shar: + cuts = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + shard_origin = [str(c.shard_origin).split("/")[2] for c in cuts] + unique_origin = set(shard_origin) + for ori in shard_origin: + if ori in shard_count: + shard_count[ori] += 1 + else: + shard_count[ori] = 1 + count = {orig: 0 for orig in unique_origin} + for sh in shard_origin: + count[sh] += 1 + + if batch_idx % 200 == 1: + logging.info(count) + logging.info(f"All shards source by far: {shard_count}") + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + # if not saved_bad_model: + # save_bad_model(suffix="-first-warning") + # saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {params.batch_idx_train}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + model.train() + if params.use_shar and params.batch_idx_train > params.max_iters: + return + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + local_rank = setup_distributed() + else: + local_rank = rank + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}") + + sp = None + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + num_param_prediction_head = sum([p.numel() for p in model.codebook_loss_net.parameters()]) + logging.info(f"Number of encoder parameters: {num_param - num_param_prediction_head}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # Setting the encoder lr scale + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + parameters = get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ) + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=params.warmup_batches) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + # When using zip sampler to combine speech and audio data + # we distribute the max-duration to each sampler according to their + # total duration + train_cuts = {} + train_cuts_duration = [] + + # NOTE: We combine all the ASR data together, and use one sampler. + # We use CutSet.mux to sample with weight, the weight is the number + # of training samples (NOT the total duration)! + asr_training_cuts = [] + asr_training_cuts_lens = [] + asr_training_cuts_duration = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # n_cuts + librispeech_cuts_duration = 100 + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 + librispeech_cuts_duration = 960 + librispeech_cuts = librispeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=0 + asr_training_cuts.append(librispeech_cuts) + asr_training_cuts_lens.append(librispeech_cuts_len * params.repeat_librispeech) + asr_training_cuts_duration.append(librispeech_cuts_duration * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "xs": 9389, + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + gigaspeech_cuts_duration = { + "xs": 10, + "s": 250, # 250 hrs + "m": 1000, # 1000 hrs + "l": 2500, # 2500 hrs + "xl": 10000 # 10000 hrs + } + gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(gigaspeech_cuts) + asr_training_cuts_lens.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + asr_training_cuts_duration.append(gigaspeech_cuts_duration[params.gigaspeech_subset]) + + if params.use_libriheavy: + libriheavy_cuts = librispeech.libriheavy_train_cuts() + libriheavy_cuts_len = { + "small": 122512 * 0.9, # 122512 + "medium": 996017, # 1093040, fewer after filtering + "large": 10093746, + } + libriheavy_cuts_duration = { + "small": 466, + "medium": 4148, + "large": 42074, + } + libriheavy_cuts = libriheavy_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(libriheavy_cuts) + asr_training_cuts_lens.append(libriheavy_cuts_len[params.libriheavy_subset]) + asr_training_cuts_duration.append(libriheavy_cuts_duration[params.libriheavy_subset]) + + if params.use_voxpopuli: + voxpopuli_cuts = librispeech.voxpopuli_unlabelled_cuts() + voxpopuli_cuts = voxpopuli_cuts.map(partial(_add_task_id, 1)) + # vox en unlabelled: 24151 hrs, 3059813 cuts + asr_training_cuts.append(voxpopuli_cuts) + asr_training_cuts_lens.append(3059813) + asr_training_cuts_duration.append(24151) + + if params.use_mls: + mls_cuts = librispeech.mls_cuts() + mls_cuts = mls_cuts.map(partial(_add_task_id, 1)) + # mls cuts: 10801 hrs, 2619190 cuts + asr_training_cuts.append(mls_cuts) + asr_training_cuts_lens.append(2619190) + asr_training_cuts_duration.append(10801) + + if params.use_extra_english_dataset: + englishs_cuts, english_cut_durations, english_cuts_len = librispeech.multi_english_cuts() + englishs_cuts = englishs_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(englishs_cuts) + asr_training_cuts_lens.append(english_cuts_len) + asr_training_cuts_duration.append(english_cut_durations) + + if params.use_wenetspeech: + wenetspeech_cuts = librispeech.wenetspeech_train_cuts() + wenetspeech_cuts_len = { + "S": 151600, + "M": 1514500, + "L": 13306651, # TODO: update this number + } + wenetspeech_cuts_duration = { + "S": 100, + "M": 1000, + "L": 9700, + } + wenetspeech_cuts = wenetspeech_cuts.map(partial(_add_task_id, 1)) # ASR task ID=1 + asr_training_cuts.append(wenetspeech_cuts) + asr_training_cuts_lens.append(wenetspeech_cuts_len[params.wenetspeech_subset]) + asr_training_cuts_duration.append(wenetspeech_cuts_duration[params.wenetspeech_subset]) + + if params.use_extra_chinese_dataset: + chineses_cuts, chinese_cut_durations, chinese_cuts_len = librispeech.multi_chinese_cuts() + chineses_cuts = chineses_cuts.map(partial(_add_task_id, 1)) + asr_training_cuts.append(chineses_cuts) + asr_training_cuts_lens.append(chinese_cuts_len) + asr_training_cuts_duration.append(chinese_cut_durations) + + if params.use_emotion_dataset: + other_emotion_cuts = librispeech.multi_emotion_cuts() + msp_podcast_cuts = librispeech.msp_podcast_train_cust() + emotion_cuts = CutSet.mux( + *[other_emotion_cuts, msp_podcast_cuts], + weights=[134, 52], + stop_early=False, + ) + emotion_cuts = emotion_cuts.map(partial(_add_task_id, 1)) # for now we treat ER cuts as part of ASR cuts + asr_training_cuts.append(emotion_cuts) + asr_training_cuts_lens.append(130297 * params.repeat_emo) # 46267 + 84030 + asr_training_cuts_duration.append(186 * params.repeat_emo) # 52 + 134 + + # combine the asr data into a BIG cut + if len(asr_training_cuts) >= 1: + # assert len(asr_training_cuts) >= 1, len(asr_training_cuts) + logging.info(f"ASR cuts: {asr_training_cuts}") + logging.info(f"ASR cuts length: {asr_training_cuts_lens}") + logging.info(f"ASR cuts duration: {asr_training_cuts_duration}") + if len(asr_training_cuts) > 1: + asr_training_cuts = CutSet.mux( + *asr_training_cuts, + weights=asr_training_cuts_lens, + stop_early=False, + ) + else: + asr_training_cuts = asr_training_cuts[0] + + train_cuts["cuts_asr"] = asr_training_cuts + train_cuts_duration.append(sum(asr_training_cuts_duration)) + + # audio data + audio_training_cuts = [] + audio_training_cuts_lens = [] + audio_training_cuts_duration = [] + if params.use_audioset: + logging.info(f"Getting audioset cuts") + if params.repeat_audioset > 1 and not params.use_shar: + audioset_cuts = librispeech.audioset_cuts().repeat( + times=params.repeat_audioset, + preserve_id=False + ) + else: + audioset_cuts = librispeech.audioset_cuts() + + audioset_cuts_lens = { + "balanced": 21155, + "full": 1904746, + } + audioset_cuts_duration = { + "balanced": 50, + "full": params.at_num_samples * 10 / 3600 if params.at_weighted_sampler else 5244, + } + audioset_cuts = audioset_cuts.map(partial(_add_task_id, 2)) + num_audio_cuts = audioset_cuts_lens[params.audioset_subset] * params.repeat_audioset + audio_training_cuts.append(audioset_cuts) + audio_training_cuts_lens.append(num_audio_cuts) + audio_training_cuts_duration.append(audioset_cuts_duration[params.audioset_subset] * params.repeat_audioset) + + if params.use_music4all: + music4all_cuts = librispeech.music4all_cuts() # 910 hrs, 109269 cuts + music4all_cuts = music4all_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(music4all_cuts) + audio_training_cuts_lens.append(109269 * params.repeat_music4all) + audio_training_cuts_duration.append(910 * params.repeat_music4all) + + if params.use_vggsound: + vggsound_cuts = librispeech.vggsound_train_cuts() # 427 hrs, 154142 cuts + vggsound_cuts = vggsound_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(vggsound_cuts) + audio_training_cuts_lens.append(154142 * params.repeat_vggsound) + audio_training_cuts_duration.append(427 * params.repeat_vggsound) + + if params.use_bbceffect: + # split into 10s + bbceffect_cuts = librispeech.bbc_soundeffect_train_cuts() # 430 hrs, 160905 cuts + bbceffect_cuts = bbceffect_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(bbceffect_cuts) + audio_training_cuts_lens.append(160905) + audio_training_cuts_duration.append(430) + + if params.use_freesound: + # split into 10s + freesound_cuts = librispeech.freesound_train_cuts() # 2516 hrs, 1073093 cuts + freesound_cuts = freesound_cuts.map(partial(_add_task_id, 2)) + audio_training_cuts.append(freesound_cuts) + audio_training_cuts_lens.append(1073093) + audio_training_cuts_duration.append(2516) + + + # combine the audio datasets + if len(audio_training_cuts) >= 1: + logging.info(f"audio cuts: {audio_training_cuts}") + logging.info(f"audio cuts length: {audio_training_cuts_lens}") + logging.info(f"audio cuts duration: {audio_training_cuts_duration}") + if len(audio_training_cuts) > 1: + audio_training_cuts = CutSet.mux( + *audio_training_cuts, + weights=audio_training_cuts_lens, + stop_early=False, + ) + else: + audio_training_cuts = audio_training_cuts[0] + + train_cuts["cuts_audio"] = audio_training_cuts + train_cuts_duration.append(sum(audio_training_cuts_duration)) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(train_cuts_duration) + params.train_duration = sum(train_cuts_duration) + + def remove_short_and_long_utt(c: Cut): + if c.duration < 0.98 or c.duration > 31: + return False + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + # Combine the ASR and audio data together + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + train_cuts_lens = [sum(asr_training_cuts_lens), num_audio_cuts] + logging.info(f"Training cuts lens: {train_cuts_lens}") + train_cuts = CutSet.mux( + *train_cuts, + weights=train_cuts_lens, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + # NOTE: when using Shar, the sampler shouldn't have state + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sampling_weight=train_cuts_duration, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech or params.use_libriheavy: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + ls_valid_cuts = ls_valid_cuts.map(partial(_add_task_id, 1)) + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + giga_dev_cuts = giga_dev_cuts.map(partial(_add_task_id, 1)) + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + if params.use_wenetspeech: + wenet_dev_cuts = librispeech.wenetspeech_valid_cuts() + wenet_dev_cuts = wenet_dev_cuts.map(partial(_add_task_id, 1)) + asr_wenet_valid_dl = librispeech.valid_dataloaders(wenet_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_wenet") + valid_dls.append(asr_wenet_valid_dl) + + if params.use_emotion_dataset: + msp_podcast_dev_cuts = librispeech.msp_podcast_dev_cust() + msp_podcast_dev_cuts = msp_podcast_dev_cuts.map(partial(_add_task_id, 1)) + er_msp_dev_dl = librispeech.valid_dataloaders(msp_podcast_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ER_msp_podcast") + valid_dls.append(er_msp_dev_dl) + + if params.use_audioset: + as_eval_cuts = librispeech.audioset_eval_cuts() + as_eval_cuts = as_eval_cuts.map(partial(_add_task_id, 2)) + at_valid_dl = librispeech.valid_dataloaders(as_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_as") + valid_dls.append(at_valid_dl) + + if params.use_vggsound: + vggsound_eval_cuts = librispeech.vggsound_test_cuts() + vggsound_eval_cuts = vggsound_eval_cuts.map(partial(_add_task_id, 2)) + vggsound_valid_dl = librispeech.valid_dataloaders(vggsound_eval_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_vggsound") + valid_dls.append(vggsound_valid_dl) + + if params.use_bbceffect: + bbc_test_cuts = librispeech.bbc_soundeffect_test_cuts() + bbc_test_cuts = bbc_test_cuts.map(partial(_add_task_id, 2)) + bbc_test_dl = librispeech.valid_dataloaders(bbc_test_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_bbc") + valid_dls.append(bbc_test_dl) + + if params.use_freesound: + freesound_test_cuts = librispeech.freesound_test_cuts() + freesound_test_cuts = freesound_test_cuts.map(partial(_add_task_id, 2)) + freesound_test_dl = librispeech.valid_dataloaders(freesound_test_cuts, world_size=world_size, rank=rank,) + valid_sets.append("AT_freesound") + valid_dls.append(freesound_test_dl) + + logging.info(f"Validation sets: {valid_sets}") + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + if not params.use_shar: + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.batch_idx_train > params.max_iters: + logging.info(f"Already reached the maximum iterations: {params.max_iters}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler if not params.use_shar else None, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_mvq_kd.py b/egs/emilia/CLAP/spear/train_mvq_kd.py new file mode 100644 index 0000000000..744c3da0b8 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_mvq_kd.py @@ -0,0 +1,1360 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union, List + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_rank, + get_world_size, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # MVQ distillation related + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + model = AsrKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + mvq_tokens = batch["cb_indexes"].to(device) + + # texts = batch["supervisions"]["text"] + # y = sp.encode(texts, out_type=int) + # y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + mvq_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + ) + + loss = 0.0 + + loss += mvq_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["mvq_loss"] = mvq_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # sp is None + sp = None + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + logging.info(f"Rank: {rank}") + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + train_cuts = {} + data_sampling_weight = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # 100 # the duration + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 # 960 hrs + if params.repeat_librispeech > 1: + librispeech_cuts = librispeech_cuts.repeat(params.repeat_librispeech) + # librispeech_cuts = librispeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + train_cuts["cuts_asr_libri"] = librispeech_cuts + data_sampling_weight.append(librispeech_cuts_len * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + # gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=1 + train_cuts["cuts_asr_giga"] = gigaspeech_cuts + data_sampling_weight.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(data_sampling_weight) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + logging.info(f"Training cuts: {data_sampling_weight}") + train_cuts = CutSet.mux( + *train_cuts, + weights=data_sampling_weight, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict, sampling_weight=data_sampling_weight, world_size=world_size, rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/train_mvq_kd_multinode.py b/egs/emilia/CLAP/spear/train_mvq_kd_multinode.py new file mode 100644 index 0000000000..0a5bede265 --- /dev/null +++ b/egs/emilia/CLAP/spear/train_mvq_kd_multinode.py @@ -0,0 +1,1375 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import os +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union, List + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from mtl_datamodule import MultiTaskDataModule +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrKDModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_local_rank, + get_rank, + get_world_size, + setup_dist, +) +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # MVQ distillation related + parser.add_argument( + "--distillation-layer", + type=int, + default=-1, + ) + + parser.add_argument( + "--distillation-delta", + type=int, + default=0, + ) + + parser.add_argument( + "--teacher-frame-ratio", + type=int, + default=2, + help="The frame rate ratio between teacher and student" + ) + + parser.add_argument( + "--num-codebooks", + type=int, + default=8, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--stop-early", + type=str2bool, + default=True, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + model = AsrKDModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_codebooks=params.num_codebooks, + distillation_layer=params.distillation_layer, + distillation_delta=params.distillation_delta, + teacher_frame_ratio=params.teacher_frame_ratio, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + mvq_tokens = batch["cb_indexes"].to(device) + + # texts = batch["supervisions"]["text"] + # y = sp.encode(texts, out_type=int) + # y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + mvq_loss = model( + x=feature, + x_lens=feature_lens, + codebook_indexes=mvq_tokens, + ) + + loss = 0.0 + + loss += mvq_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["mvq_loss"] = mvq_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_sets: List[str], + valid_dls: List[torch.utils.data.DataLoader], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + + logging.info(f"Epoch {params.cur_epoch}, validation on {valid_set}: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/valid_{valid_set}", params.batch_idx_train + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + if args.use_multi_node: + print(f"Using multi node!") + local_rank = get_local_rank() + else: + local_rank = rank + print( + f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}" + ) + + if world_size > 1: + setup_dist(rank, world_size, params.master_port, params.use_multi_node) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # logging.info(f"Device: {device}") + logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # sp is None + sp = None + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.cuda() + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + logging.info(f"Rank: {local_rank}") + + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = MultiTaskDataModule(args) + + train_cuts = {} + data_sampling_weight = [] + if params.use_librispeech: + if not params.full_libri: + librispeech_cuts = librispeech.train_clean_100_cuts() + librispeech_cuts_len = 85617 # 100 # the duration + else: + librispeech_cuts = librispeech.train_all_shuf_cuts() + librispeech_cuts_len = 281239 # 960 hrs + if params.repeat_librispeech > 1: + librispeech_cuts = librispeech_cuts.repeat(params.repeat_librispeech) + # librispeech_cuts = librispeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=0 + train_cuts["cuts_asr_libri"] = librispeech_cuts + data_sampling_weight.append(librispeech_cuts_len * params.repeat_librispeech) + + if params.use_gigaspeech: + gigaspeech_cuts = librispeech.gigaspeech_train_cuts() + gigaspeech_cuts_len = { + "s": 210012, # 250 hrs + "m": 859493, # 1000 hrs + "l": 2152879, # 2500 hrs + "xl": 8611516 # 10000 hrs + } + # gigaspeech_cuts = gigaspeech_cuts.map(partial(_add_dummy_embeddings_and_taskIDs, 1)) # ASR task ID=1 + train_cuts["cuts_asr_giga"] = gigaspeech_cuts + data_sampling_weight.append(gigaspeech_cuts_len[params.gigaspeech_subset]) + + assert len(train_cuts) >= 1, "At least one task should be done!" + + logging.info(train_cuts) + logging.info(data_sampling_weight) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + # If we filter the data and use weighted_sampler, the number of cuts + # will be smaller, and won't match the sampling weight + if not params.at_weighted_sampler: + for k, cuts in train_cuts.items(): + train_cuts[k] = cuts.filter(remove_short_and_long_utt) + + if params.bucketing_sampler: + assert params.zip_sampler == False + train_cuts = [item[1] for item in train_cuts.items()] + if len(train_cuts) > 1: + logging.info(f"Using mux to combine data") + logging.info(f"Training cuts: {train_cuts}") + logging.info(f"Training cuts: {data_sampling_weight}") + train_cuts = CutSet.mux( + *train_cuts, + weights=data_sampling_weight, + stop_early=params.stop_early, + ) + else: + train_cuts = train_cuts[0] + assert isinstance(train_cuts, CutSet), type(train_cuts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict, sampling_weight=data_sampling_weight, world_size=world_size, rank=rank, + ) + + valid_sets = [] + valid_dls = [] + + # valid dataloaders + if params.use_librispeech: + ls_valid_cuts = librispeech.dev_clean_cuts() + ls_valid_cuts += librispeech.dev_other_cuts() + asr_ls_valid_dl = librispeech.valid_dataloaders(ls_valid_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_ls") + valid_dls.append(asr_ls_valid_dl) + + if params.use_gigaspeech: + giga_dev_cuts = librispeech.gigaspeech_dev_cuts() + asr_giga_valid_dl = librispeech.valid_dataloaders(giga_dev_cuts, world_size=world_size, rank=rank,) + valid_sets.append("ASR_giga") + valid_dls.append(asr_giga_valid_dl) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_sets=valid_sets, + valid_dls=valid_dls, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + # y = sp.encode(supervisions["text"], out_type=int) + # num_tokens = sum(len(i) for i in y) + # logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MultiTaskDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear/utils.py b/egs/emilia/CLAP/spear/utils.py new file mode 100644 index 0000000000..84db7dd09b --- /dev/null +++ b/egs/emilia/CLAP/spear/utils.py @@ -0,0 +1,376 @@ +import collections +import logging +import os +import re +from typing import List, Tuple + +import torch +import torch.distributed as dist +from lhotse.array import Array, TemporalArray +from torch.utils.tensorboard import SummaryWriter + +from icefall.byte_utils import byte_encode +from icefall.utils import tokenize_by_CJK_char + + +def _normalize_chinese_text(text): + # 去除所有标点符号 + text = re.sub(r"[,。!?、;:“”‘’()《》【】{}·…—~]", "", text) + # 去除汉字之间的空格(确保不影响英文单词) + text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text) + text = text.upper() + return text + + +def normalize_chinese_text(c): + text = c.supervisions[0].text + text = _normalize_chinese_text(text) + c.supervisions[0].text = text + return c + + +def _normalize_english_text(text): + # 只保留字母、数字、空格和单引号,去掉其他标点符号 + text = re.sub(r"[^\w\s']", "", text) + # 转换为大写 + text = text.upper() + return text + + +def normalize_english_text(c): + text = c.supervisions[0].text + text = _normalize_english_text(text) + c.supervisions[0].text = text + return c + + +def remove_non_alphabetic(text: str, strict: bool = True) -> str: + # Recommend to set strict to False + if not strict: + # Note, this also keeps space, single quote(') + text = text.replace("-", " ") + text = text.replace("—", " ") + return re.sub(r"[^a-zA-Z0-9\s']+", "", text) + else: + # only keeps space + return re.sub(r"[^a-zA-Z\s]+", "", text) + + +def map_zh(c): + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + +def upper_only_alpha(c): + text = c.supervisions[0].text + text = remove_non_alphabetic(text.upper(), strict=False) + c.supervisions[0].text = text + return c + + +def add_dummy_text(c): + if c.supervisions[0].text is None: + c.supervisions[ + 0 + ].text = "Dummy text added as a place holder. Please ignore this if possible." + return c + + +def _add_dummy_embeddings_and_taskIDs(task_ID: int, c): + whisper_embedding_dict = { + "array": { + "storage_type": "numpy_hdf5", + "storage_path": "data/dummy_embeddings/dummy_whisper_embedding_1510.h5", + "storage_key": "dummy_whisper_embedding_1510", + "shape": [1510, 1280], + }, + "temporal_dim": 0, + "frame_shift": 0.02, + "start": 0, + } + whisper_dummy_embedding = TemporalArray.from_dict(whisper_embedding_dict) + + whisper_cb_indexes_dict = { + "array": { + "storage_type": "numpy_hdf5", + "storage_path": "data/dummy_embeddings/dummy_whisper_codebook_indexes_1510.h5", + "storage_key": "dummy_whisper_codebook_indexes_1510", + "shape": [1510, 16], + }, + "temporal_dim": 0, + "frame_shift": 0.02, + "start": 0, + } + whisper_cb_indexes = TemporalArray.from_dict(whisper_cb_indexes_dict) + + beats_embedding_dict = { + "storage_type": "numpy_hdf5", + "storage_path": "data/dummy_embeddings/dummy_beats_embedding.h5", + "storage_key": "dummy_beats_embedding", + "shape": [527], + } + beats_dummy_embedding = Array.from_dict(beats_embedding_dict) + + ecapa_embedding_dict = { + "storage_type": "numpy_hdf5", + "storage_path": "dummy_ecapa_embedding.h5", + "storage_key": "dummy_ecapa_embedding", + "shape": [1, 192], + } + ecapa_dummy_embedding = Array.from_dict(ecapa_embedding_dict) + + mert_embedding_dict = { + "array": { + "storage_type": "numpy_hdf5", + "storage_path": "data/dummy_embeddings/dummy_mert_embedding_2260.h5", + "storage_key": "dummy_mert_embedding", + "shape": [2260, 1024], + }, + "temporal_dim": 0, + "frame_shift": 0.013333333333333334, + "start": 0, + } + mert_dummy_embedding = TemporalArray.from_dict(mert_embedding_dict) + + def add_embeddings(c): + # if not c.has_custom("whisper_embedding"): + # c.whisper_embedding = whisper_dummy_embedding + if not c.has_custom("codebook_indexes"): + c.codebook_indexes = whisper_cb_indexes + + # if not c.has_custom("ecapa_embedding"): + # c.ecapa_embedding = ecapa_dummy_embedding + if not c.has_custom("beats_embedding"): + c.beats_embedding = beats_dummy_embedding + # if not c.supervisions[0].has_custom("audio_event"): + # c.supervisions[0].audio_event = "0" + if c.supervisions[0].text is None: + c.supervisions[ + 0 + ].text = ( + "Dummy text added as a place holder. Please ignore this if possible." + ) + if task_ID is not None: + c.task_id = task_ID + return c + + c = add_embeddings(c) + return c + + +def _add_task_id(task_id, c): + c.task_id = task_id + return c + + +def _add_language_id(lid, c): + c.language_id = lid + return c + + + +def _save_checkpoint_with_global_batch_idx( + params, + model, + optimizer=None, + sampler=None, + scheduler=None, + scaler=None, + model_avg=None, + rank: int = 0, +): + # only active when rank==0 + if rank != 0: + return + + if isinstance(model, DDP): + model = model.module + else: + model = model + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "scheduler": scheduler.state_dict() if scheduler is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + output_path = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" + + if params.save_with_client: + output_path = "brainllm:s3://yangxiaoyu/" + str(output_path) + logging.info(f"Saving checkpoint to {output_path}") + with io.BytesIO() as f: + torch.save(checkpoint, f) + f.seek(0) + params.client.put(output_path, f) + logging.info(f"Finish saving checkpoint to {output_path}") + else: + logging.info(f"Saving checkpoint to {output_path}") + torch.save(checkpoint, output_path) + + +def _save_checkpoint( + filename, + model, + model_avg=None, + params=None, + optimizer=None, + scheduler=None, + scaler=None, + sampler=None, + rank: int = 0, +): + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "scheduler": scheduler.state_dict() if scheduler is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + if "s3://" in filename: + with io.BytesIO() as f: + torch.save(checkpoint, f) + f.seek(0) + params.client.put(filename, f) + logging.info(f"Finish saving checkpoint to {filename}") + else: + torch.save(checkpoint, filename) + + +class MetricsTracker(collections.defaultdict): + def __init__(self, normalize: bool = True): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + # This class will play a role as metrics tracker. + # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + self.normalize = normalize + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + if v - v == 0: + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans_frames = "" + ans_utterances = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + if "utt_" not in k: + ans_frames += str(k) + "=" + str(norm_value) + ", " + else: + ans_utterances += str(k) + "=" + str(norm_value) + if k == "utt_duration": + ans_utterances += " frames, " + elif k == "utt_pad_proportion": + ans_utterances += ", " + else: + raise ValueError(f"Unexpected key: {k}") + frames = "%.2f" % self["frames"] + ans_frames += "over " + str(frames) + " frames. " + if ans_utterances != "": + utterances = "%.2f" % self["utterances"] + ans_utterances += "over " + str(utterances) + " utterances." + + return ans_frames + ans_utterances + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self["frames"] if "frames" in self else 1 + num_utterances = self["utterances"] if "utterances" in self else 1 + ans = [] + for k, v in self.items(): + if k == "frames" or k == "utterances": + continue + if not self.normalize: + ans.append((k, float(v))) + continue + if ("audio_tagging" in k) or ("speaker_verification" in k): + norm_value = float(v) / num_utterances + else: + norm_value = ( + float(v) / num_frames + if "utt_" not in k + else float(v) / num_utterances + ) + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) + + +if __name__ == "__main__": + text = "你好 , 这是 一个 测试 句子 !Hello 希望 这段 代码 能正常 工作 。" + normalized_text = normalize_chinese_text(text) + print(normalized_text) + + text = "Hello, world! It's a great day to learn NLP." + normalized_text = normalize_english_text(text) + print(normalized_text) diff --git a/egs/emilia/CLAP/spear/zipformer.py b/egs/emilia/CLAP/spear/zipformer.py new file mode 120000 index 0000000000..23011dda71 --- /dev/null +++ b/egs/emilia/CLAP/spear/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear/zipformer2.py b/egs/emilia/CLAP/spear/zipformer2.py new file mode 100644 index 0000000000..c43b7ae38b --- /dev/null +++ b/egs/emilia/CLAP/spear/zipformer2.py @@ -0,0 +1,2447 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + if output_downsampling_factor >= 2: + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + else: + self.downsample_output = None + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + return_middle_out: bool = False, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + + if self.output_downsampling_factor >= 2: + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + else: + lengths = x_lens + if return_middle_out: + return x, lengths, outputs + else: + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=-1) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) diff --git a/egs/emilia/CLAP/spear/zipformer_bf16.py b/egs/emilia/CLAP/spear/zipformer_bf16.py new file mode 100644 index 0000000000..acfd2dffe0 --- /dev/null +++ b/egs/emilia/CLAP/spear/zipformer_bf16.py @@ -0,0 +1,2456 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union +import logging +import torch +import random +from encoder_interface import EncoderInterface +from scaling_bf16 import ( + Balancer, + BiasNorm, + Dropout2, + ChunkCausalDepthwiseConv1d, + ActivationDropoutAndLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + penalize_abs_values_gt, + softmax, + ScheduledFloat, + FloatLike, + limit_param_value, + convert_num_channels, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + return_middle_out: bool = False, + freezing_layer_idx: List[int] = [], + forward_first_n: int = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + return_middle_out: + Return the layer-wise output + freezing_layer_idx: + A list of integers, indicating which layers in the encoder should + be frozen. If not initialized, it will be an empty list. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + layerwise_outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + freeze_current_layer = i in freezing_layer_idx # should return False for an empty list + if freeze_current_layer: + module.eval() + + with torch.set_grad_enabled(not freeze_current_layer): + x, layer_results = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) + outputs.append(x) + layerwise_outputs.append(layer_results) + + if forward_first_n is not None and i == forward_first_n: + return x, x_lens, layerwise_outputs + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + if return_middle_out: + return x, lengths, layerwise_outputs + else: + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + outputs = [] + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + outputs.append(output) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output, outputs + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src, layer_results = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + # upsample layerwise results + # layer_results = [self.upsample(res)[: src_orig.shape[0]] for res in layer_results] + + return self.out_combiner(src_orig, src), layer_results + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(x.dtype).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=-1) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) diff --git a/egs/emilia/CLAP/spear_roberta/asr_datamodule.py b/egs/emilia/CLAP/spear_roberta/asr_datamodule.py new file mode 100644 index 0000000000..143b2ad10d --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/asr_datamodule.py @@ -0,0 +1,531 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + combine, + load_manifest, + load_manifest_lazy, +) +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class DataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders. + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=16, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--drop-features", + type=str2bool, + default=False, + help="If drop the pre-computed features", + ) + + group.add_argument( + "--return-audio", + type=str2bool, + default=False, + help="Return audio while collating batch", + ) + + group.add_argument( + "--num-mel-bins", + type=int, + default=128, + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: int = 1, + rank: int = 0, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info( + "Using DynamicBucketingSampler with strict FixedBucketBatchSizeConstraint." + ) + constraint = FixedBucketBatchSizeConstraint( + max_seq_len_buckets=self.args.max_seq_len_buckets, + batch_sizes=self.args.fixed_batch_sizes, + ) + train_sampler = DynamicBucketingSampler( + cuts_train, + constraint=constraint, + shuffle=True, + drop_last=True, + duration_bins=self.args.duration_bins, + buffer_size=self.args.num_buckets * 5000, + sync_buckets=True, + concurrent=False, + world_size=world_size, + rank=rank, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=4, + persistent_workers=True, + pin_memory=True, + prefetch_factor=16, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: int = 1, + rank: int = 0, + ) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=4, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins)) + ) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=4, + ) + return test_dl + + def estimate_duration_bins( + self, + cuts: CutSet, + world_size: int = 1, + rank: int = 0, + ) -> List[float]: + logging.info("Estimating duration bins for FixedBucketBatchSizeConstraint") + + dummy_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=True, + drop_last=True, + buffer_size=self.args.num_buckets * 5000, + sync_buckets=True, + concurrent=False, + world_size=world_size, + rank=rank, + ) + duration_bins = dummy_sampler.duration_bins + del dummy_sampler + return duration_bins + + @lru_cache() + def emilia_en_cuts(self) -> CutSet: + logging.info("About to get Emilia EN tars") + filenames = glob.glob("./download/Emilia/EN/*.tar") + logging.info(f"Loading Emilia {len(filenames)} tars in lazy mode") + return CutSet.from_webdataset( + filenames, + shuffle_shards=True, + split_by_worker=False, + split_by_node=False, + ) + + @lru_cache() + def paraspeechcaps_train_base_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps train-base shuffled cuts") + return load_manifest_lazy( + self.args.manifest_dir + / "paraspeechcaps_cuts_train_base_shuf-selected.jsonl.gz" + ) + + @lru_cache() + def paraspeechcaps_dev_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps dev cuts") + splits = ["voxceleb", "expresso", "ears"] + return combine( + load_manifest_lazy( + self.args.manifest_dir + / f"paraspeechcaps_cuts_dev-{s}-selected.jsonl.gz" + ) + for s in splits + ) + + @lru_cache() + def paraspeechcaps_test_cuts(self) -> CutSet: + logging.info("About to get paraspeechcaps test cuts") + splits = ["voxceleb", "expresso", "ears"] + return combine( + load_manifest_lazy( + self.args.manifest_dir + / f"paraspeechcaps_cuts_test-{s}-selected.jsonl.gz" + ) + for s in splits + ) + + @lru_cache() + def iemocap_cuts(self) -> CutSet: + logging.info("About to get iemocap cuts") + return load_manifest_lazy(self.args.manifest_dir / "iemocap_cuts_all.jsonl.gz") + + @lru_cache() + def ravdess_cuts(self) -> CutSet: + logging.info("About to get ravdess cuts") + return load_manifest_lazy(self.args.manifest_dir / "ravdess_cuts_all.jsonl.gz") + + @lru_cache() + def cremad_cuts(self) -> CutSet: + logging.info("About to get crema-d cuts") + return load_manifest_lazy(self.args.manifest_dir / "cremad_cuts_test.jsonl.gz") diff --git a/egs/emilia/CLAP/spear_roberta/attribute_perturbation.py b/egs/emilia/CLAP/spear_roberta/attribute_perturbation.py new file mode 100644 index 0000000000..e11fc072b4 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/attribute_perturbation.py @@ -0,0 +1,802 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import re +from typing import Dict, List + + +def match_case(word, replacement): + if word.isupper(): + return replacement.upper() + elif word[0].isupper(): + return replacement.title() + else: + return replacement.lower() + + +GENDER = ["female", "male"] +GENDER_PAIRS = [ + ("female", "male"), + ("woman", "man"), + ("girl", "boy"), + ("feminine", "masculine"), + ("she", "he"), + ("her", "his"), + ("herself", "himself"), +] + + +def perturb_gender_in_text(text: str, gender: str) -> dict: + replaced_flag = False + + # 1. 确定替换方向 + # 如果当前标签是 female,我们要把 text 里的 female 词汇替换成 male (构造负例) + # GENDER_PAIRS 的结构是 (female_word, male_word) + # source_idx = 0 (female), target_idx = 1 (male) + if gender == "female": + source_idx, target_idx = 0, 1 + elif gender == "male": + source_idx, target_idx = 1, 0 + else: + raise ValueError(f"Unknown gender: {gender}") + + # 2. 依次遍历关键词进行匹配 + for pair in GENDER_PAIRS: + src_word = pair[source_idx] + tgt_word = pair[target_idx] + + # 构造正则表达式 + # Group 1: ("[^"]*") -> 匹配双引号括起来的任意内容 (保护区) + # Group 2: (\bkeyword\b) -> 匹配目标关键词,\b 确保词边界 (匹配区) + # re.IGNORECASE -> 忽略大小写 + pattern = r'("[^"]*")|\b(' + re.escape(src_word) + r")\b" + + def replace_callback(match): + nonlocal replaced_flag + + # 如果匹配到了 Group 1 (双引号内容),原样返回,不动它 + if match.group(1): + return match.group(1) + + # 如果匹配到了 Group 2 (关键词),进行替换 + if match.group(2): + original_word = match.group(2) + replaced_flag = True + return match_case(original_word, tgt_word) + + # 进行替换 + text = re.sub(pattern, replace_callback, text, flags=re.IGNORECASE) + + return {"flag": replaced_flag, "text": text} + + +SPEAKING_RATE = ["fast speed", "slow speed"] # "measured speed" +SPEAKING_RATE_PAIRS = [ + ("fast-paced", "slow-paced"), + ("fast", "slow"), + ("quick", "slow"), + ("rapid", "slow"), + ("rushed", "deliberate"), + ("hurried", "unhurried"), + ("quickly", "slowly"), + ("rapidly", "slowly"), + ("fast", "measured"), +] + + +def perturb_speaking_rate_in_text(text: str, speaking_rate: str) -> dict: + replaced_flag = False + + # 1. 确定替换方向 + # source_idx: 我们要在文本中搜索的词 + # target_idx: 我们要替换成的词 + if speaking_rate == "fast speed": + source_idx, target_idx = 0, 1 + elif speaking_rate == "slow speed": + source_idx, target_idx = 1, 0 + else: # 如果标签不合法或者是 measured (不考虑),直接返回 + return {"flag": False, "text": text} + + # 2. 依次遍历关键词进行匹配 + for pair in SPEAKING_RATE_PAIRS: + src_word = pair[source_idx] + tgt_word = pair[target_idx] + + # 构造正则表达式 + # Group 1: ("[^"]*") -> 匹配双引号括起来的任意内容 (保护区) + # Group 2: (\bkeyword\b) -> 匹配目标关键词,\b 确保词边界 (匹配区) + # re.IGNORECASE -> 忽略大小写 + # re.escape(src_word) 用于处理像 "fast-paced" 中间的连字符,防止被正则误读 + pattern = r'("[^"]*")|\b(' + re.escape(src_word) + r")\b" + + def replace_callback(match): + nonlocal replaced_flag + + # 如果匹配到了 Group 1 (双引号内容),原样返回,不动它 + if match.group(1): + return match.group(1) + + # 如果匹配到了 Group 2 (关键词),进行替换 + if match.group(2): + original_word = match.group(2) + replaced_flag = True + return match_case(original_word, tgt_word) + + # 进行替换 + text = re.sub(pattern, replace_callback, text, flags=re.IGNORECASE) + + return {"flag": replaced_flag, "text": text} + + +PITCH = ["high-pitched", "low-pitched"] # "medium-pitched" +PITCH_PAIRS = [ + ("high-pitched", "low-pitched"), + ("higher", "lower"), + ("high", "low"), + ("raising", "lowering"), + ("rises", "falls"), + ("rising", "falling"), + ("raised", "lowered"), + ("upward", "downward"), +] + + +def perturb_pitch_in_text(text: str, pitch: str) -> dict: + replaced_flag = False + + # 1. 确定替换方向 + if pitch == "high-pitched": + source_idx, target_idx = 0, 1 + elif pitch == "low-pitched": + source_idx, target_idx = 1, 0 + else: # 如果标签不合法或者是 medium-pitched (不考虑),直接返回 + return {"flag": False, "text": text} + + # 2. 依次遍历关键词进行匹配 + for pair in PITCH_PAIRS: + src_word = pair[source_idx] + tgt_word = pair[target_idx] + + # 构造正则表达式 + # Group 1: ("[^"]*") -> 匹配双引号括起来的任意内容 (保护区) + # Group 2: (\bkeyword\b) -> 匹配目标关键词,\b 确保词边界 (匹配区) + # re.IGNORECASE -> 忽略大小写 + # re.escape(src_word) 用于处理像 "fast-paced" 中间的连字符,防止被正则误读 + pattern = r'("[^"]*")|\b(' + re.escape(src_word) + r")\b" + + def replace_callback(match): + nonlocal replaced_flag + + # 如果匹配到了 Group 1 (双引号内容),原样返回,不动它 + if match.group(1): + return match.group(1) + + # 如果匹配到了 Group 2 (关键词),进行替换 + if match.group(2): + original_word = match.group(2) + replaced_flag = True + return match_case(original_word, tgt_word) + + # 进行替换 + text = re.sub(pattern, replace_callback, text, flags=re.IGNORECASE) + + return {"flag": replaced_flag, "text": text} + + +ACCENT = [ + "american", + "argentine", + "australian", + "belgian", + "brazilian", + "british", + "british-american", + "british-guyanese", + "brooklyn/new york", + "canadian", + "cantonese", + "chilean", + "chinese", + "colombian", + "colombian-american", + "croatian", + "czech", + "dari", + "dominican", + "dutch", + "english", + "filipino", + "finnish", + "french", + "german", + "hungarian", + "indian", + "irish", + "italian", + "jamaican", + "japanese", + "jordanian", + "mandarin", + "mexican", + "new zealand", + "nigerian", + "northern irish", + "norwegian", + "paraguayan", + "polish", + "portuguese", + "romanian", + "russian", + "scottish", + "serbian", + "slovenian", + "southern american", + "spanish", + "swedish", + "swiss", + "turkish", + "ukrainian", + "welsh", +] +ACCENT_GROUPS: Dict[str, List[str]] = { + # 【英语核心圈】:母语为英语,区别主要在元音变化和R音 + "english_native": [ + "american", + "british", + "australian", + "canadian", + "new zealand", + "irish", + "scottish", + "welsh", + "southern american", + "brooklyn/new york", + "northern irish", + "british-american", + "english", + ], + # 【英语 L2 / 独特韵律】:非母语英语,或有强烈地域韵律特征的英语 + "english_l2_distinct": [ + "indian", + "nigerian", + "jamaican", + "filipino", + "british-guyanese", + "colombian-american", # 虽有American后缀,但口音特征往往带有明显的非母语韵律 + ], + # 【拉丁语族】:元音清晰、音节速率快、重音模式相似 + "romance_latin": [ + "spanish", + "mexican", + "colombian", + "argentine", + "chilean", + "paraguayan", + "portuguese", + "brazilian", + "italian", + "french", + "romanian", + "dominican", + ], + # 【日耳曼与斯拉夫】:辅音丛多、语调相对平直或有特定重音起伏 + "germanic_slavic": [ + "german", + "dutch", + "swedish", + "norwegian", + "swiss", + "belgian", + "russian", + "ukrainian", + "polish", + "czech", + "croatian", + "serbian", + "slovenian", + "hungarian", + "finnish", + ], + # 【亚洲/声调语言】:受声调或高低音重音(Pitch Accent)影响明显的口音 + "asian_tonal": [ + "mandarin", + "cantonese", + "chinese", + "japanese", + ], + # 【中东/其他】:喉音特征、独特的元音发音 + "middle_eastern_other": ["turkish", "dari", "jordanian"], +} +ACCENT2GROUP = { + accent: group_name + for group_name, accents in ACCENT_GROUPS.items() + for accent in accents +} + + +def sample_negative_accent( + accent: str, + p_intra_group: float = 0.20, +) -> str: + # 1. 确定锚点所在的组 + src_group_name = ACCENT2GROUP.get(accent) + if not src_group_name: + raise ValueError(f"Accent '{accent}' not found in any group.") + + # 2. 决定采样策略 (Intra vs Inter) + # 只有当组内成员大于1个时,才有可能进行组内采样 + group_members = ACCENT_GROUPS[src_group_name] + can_do_intra = len(group_members) > 1 + + is_intra_sample = random.random() < p_intra_group + + target_group_name = "" + negative_accent = "" + difficulty = "" + + # === 策略 A: 组内负例 (Hard) === + # 逻辑:只要随机到了intra概率,并且该组有得选,就选组内 + if is_intra_sample and can_do_intra: + difficulty = "hard (intra-group)" + target_group_name = src_group_name + + # 从组内选一个不是自己的 + candidates = [a for a in group_members if a != accent] + negative_accent = random.choice(candidates) + + # === 策略 B: 跨组负例 (Easy/Normal) === + # 逻辑:没随机到intra,或者被迫fallback到inter(因为该组只有一个独苗) + else: + difficulty = "normal (inter-group)" + + # 获取所有组名,移除当前组 + other_groups = [g for g in ACCENT_GROUPS.keys() if g != src_group_name] + + # 随机选一个组 + target_group_name = random.choice(other_groups) + + # 在该组内随机选一个口音 + negative_accent = random.choice(ACCENT_GROUPS[target_group_name]) + + return negative_accent + + +def perturb_accent_in_text(text: str, accent: str) -> dict: + if accent not in ACCENT2GROUP: + return {"flag": False, "text": text} + + # 1. 获取负例目标 (Target) + # 每次调用都会随机采样,可能是 Hard 也可能是 Easy + tgt_accent = sample_negative_accent(accent) + + replaced_flag = False + + # 构造正则表达式 + # Group 1: ("[^"]*") -> 匹配双引号括起来的任意内容 (保护区) + # Group 2: (\bkeyword\b) -> 匹配目标关键词,\b 确保词边界 (匹配区) + # re.IGNORECASE -> 忽略大小写 + # re.escape(accent) 用于连字符,防止被正则误读 + pattern = r'("[^"]*")|\b(' + re.escape(accent) + r")\b" + + def replace_callback(match): + nonlocal replaced_flag + + # 如果匹配到了 Group 1 (双引号内容),原样返回,不动它 + if match.group(1): + return match.group(1) + + # 如果匹配到了 Group 2 (关键词),进行替换 + if match.group(2): + original_word = match.group(2) + replaced_flag = True + return match_case(original_word, tgt_accent) + + # 进行替换 + text = re.sub(pattern, replace_callback, text, flags=re.IGNORECASE) + + return {"flag": replaced_flag, "text": text} + + +INTRINSIC_TAGS = [ + "authoritative", + "booming", + "crisp", + "deep", + "flowing", + "guttural", + "hesitant", + "hushed", + "husky", + "inviting", + "lisp", + "monotone", + "monotonous", + "nasal", + "pitchy", + "punctuated", + "raspy", + "shrill", + "silky", + "slurred", + "smooth", + "soft", + "staccato", + "stammering", + "upbeat", + "vocal-fry", +] +INTRINSIC_PAIRS = [ + # --- 1. 质感/音色 (Texture: Rough vs Smooth) --- + ("raspy", "smooth"), # 粗糙/沙哑 vs 光滑 + ("raspiness", "smoothness"), + ("raspily", "smoothly"), + ("raspy", "silky"), # 沙哑 vs 丝滑 + ("raspiness", "silkiness"), + ("raspily", "silkily"), + ("guttural", "silky"), # 喉音 vs 丝滑 + ("gutturally", "silkily"), + ("vocal-fry", "smoothness"), # 气泡音 vs 光滑 + ("slurred", "crisp"), # 含糊不清 vs 清晰干脆 + ("slurring", "crispness"), + ("slurringly", "crisply"), + ("husky", "crisp"), # 烟嗓 vs 清脆 + ("huskiness", "crispness"), + ("huskily", "crisply"), + ("nasal", "deep"), # 鼻音 vs 深沉 + ("nasality", "depth"), + ("nasally", "deeply"), + # --- 2. 节奏/流利度 (Rhythm: Broken vs Flowing) --- + ("staccato", "flowing"), + ("punctuated", "flowing"), # 强调/顿挫 vs 流畅 + ("punctuation", "flow"), + ("stammering", "flowing"), # 结巴 vs 流畅 + ("stammer", "flow"), + ("stammeringly", "flowingly"), + ("hesitant", "flowing"), # 迟疑 vs 流畅 + ("hesitance", "flow"), + ("hesitantly", "flowingly"), + ("lisp", "crispness"), # 口齿不清 vs 清晰 + ("lisping", "crisp"), + ("lispingly", "crisply"), + # --- 3. 音高 (Pitch: High vs Low) --- + ("shrill", "deep"), # 尖锐 vs 深沉 + ("shrillness", "depth"), + ("shrilly", "deeply"), + ("pitchy", "monotone"), # 音调起伏 vs 单调平直 + ("pitchiness", "monotony"), + ("pitchily", "monotonously"), + # --- 4. 能量/情绪 (Energy: High/Dynamic vs Low/Static) --- + ("booming", "hushed"), # 洪亮 vs 低声 + ("boom", "hush"), + ("booming", "hushedly"), + ("booming", "soft"), # 洪亮 vs 轻柔 + ("boom", "softness"), + ("booming", "softly"), + ("upbeat", "monotonous"), # 欢快 vs 单调乏味 + ("upbeat", "monotone"), # 欢快 vs 单调 + ("authoritative", "hesitant"), # 权威 vs 迟疑 + ("authority", "hesitance"), + ("authoritatively", "hesitantly"), + ("inviting", "authoritative"), # 亲切 vs 权威 + ("invitation", "authority"), + ("invitingly", "authoritatively"), +] +INTRINSIC_TAG_MAP: Dict[str, List[str]] = {} +[ + ( + INTRINSIC_TAG_MAP.setdefault(t1, []).append(t2), + INTRINSIC_TAG_MAP.setdefault(t2, []).append(t1), + ) + for t1, t2 in INTRINSIC_PAIRS +] + + +def perturb_intrinsic_tags(text: str, intrinsic_tags: List[str]) -> dict: + intrinsic_tags_copy = intrinsic_tags[:] + random.shuffle(intrinsic_tags_copy) + flag = False + for tag in intrinsic_tags_copy: + result_dict = perturb_intrinsic_tag_in_text(text, tag) + flag = result_dict["flag"] + text = result_dict["text"] + if flag: + break + return {"flag": flag, "text": text} + + +def perturb_intrinsic_tag_in_text(text: str, intrinsic_tag: str) -> dict: + replaced_flag = False + + # 如果标签不合法,直接返回 + if intrinsic_tag not in INTRINSIC_TAG_MAP: + return {"flag": False, "text": text} + + # 随机选择一个负例目标 + tgt_tag = random.choice(INTRINSIC_TAG_MAP[intrinsic_tag]) + + # 构造正则表达式 + # Group 1: ("[^"]*") -> 匹配双引号括起来的任意内容 (保护区) + # Group 2: (\bkeyword\b) -> 匹配目标关键词,\b 确保词边界 (匹配区) + # re.IGNORECASE -> 忽略大小写 + pattern = r'("[^"]*")|\b(' + re.escape(intrinsic_tag) + r")\b" + + def replace_callback(match): + nonlocal replaced_flag + + # 如果匹配到了 Group 1 (双引号内容),原样返回,不动它 + if match.group(1): + return match.group(1) + + # 如果匹配到了 Group 2 (关键词),进行替换 + if match.group(2): + original_word = match.group(2) + replaced_flag = True + return match_case(original_word, tgt_tag) + + # 进行替换 + text = re.sub(pattern, replace_callback, text, flags=re.IGNORECASE) + + return {"flag": replaced_flag, "text": text} + + +SITUATIONAL_TAGS = [ + "admiring", + "angry", + "animated", + "anxious", + "awed", + "bored", + "calm", + "confused", + "desirous", + "disgusted", + "enthusiastic", + "enunciated", + "guilt", + "happy", + "laughing", + "loud", + "pained", + "passive", + "saddened", + "sarcastic", + "scared", + "singsong", + "sleepy", + "sympathetic", + "whispered", +] +SITUATIONAL_PAIRS = [ + # --- 1. 情绪效价 (Valence: Positive vs Negative) --- + ("happy", "sad"), + ("happy", "saddened"), + ("happiness", "sadness"), + ("happiness", "pain"), + ("happiness", "anger"), + ("happy", "pained"), + ("happy", "angry"), + ("enthusiastic", "bored"), + ("enthusiasm", "boredom"), + ("laughing", "saddened"), + ("laugh", "sadness"), + ("laughter", "sadness"), + ("guilt", "happiness"), + ("guilty", "happy"), + # --- 2. 唤醒度/能量 (Arousal: High vs Low) --- + ("angry", "calm"), + ("anger", "calmness"), + ("scared", "calm"), + ("fear", "calmness"), + ("anxious", "calm"), + ("anxiety", "calmness"), + ("animated", "passive"), + ("animated", "sleepy"), + ("loud", "whispered"), + ("confused", "calm"), + ("confusion", "calmness"), + # --- 3. 态度/互动 (Attitude: Pull vs Push) --- + ("admiring", "disgusted"), + ("admiration", "disgust"), + ("desirous", "disgusted"), + ("desire", "disgust"), + ("awed", "bored"), + ("awe", "boredom"), + ("sympathetic", "sarcastic"), + ("sympathy", "sarcasm"), + ("admiring", "sarcastic"), + ("admiration", "sarcasm"), + # --- 4. 清晰度 --- + ("enunciated", "slurred"), + ("enunciation", "slurring"), + ("singsong", "monotone"), + ("singsongly", "monotonously"), +] +SITUATIONAL_TAG_MAP: Dict[str, List[str]] = {} +[ + ( + SITUATIONAL_TAG_MAP.setdefault(t1, []).append(t2), + SITUATIONAL_TAG_MAP.setdefault(t2, []).append(t1), + ) + for t1, t2 in SITUATIONAL_PAIRS +] + + +def perturb_situational_tags(text: str, situational_tags: List[str]) -> dict: + situational_tags_copy = situational_tags[:] + random.shuffle(situational_tags_copy) + flag = False + for tag in situational_tags_copy: + result_dict = perturb_situational_tag_in_text(text, tag) + flag = result_dict["flag"] + text = result_dict["text"] + if flag: + break + return {"flag": flag, "text": text} + + +def perturb_situational_tag_in_text(text: str, situational_tag: str) -> dict: + replaced_flag = False + + # 如果标签不合法,直接返回 + if situational_tag not in SITUATIONAL_TAG_MAP: + return {"flag": False, "text": text} + + # 随机选择一个负例目标 + tgt_tag = random.choice(SITUATIONAL_TAG_MAP[situational_tag]) + + # 构造正则表达式 + # Group 1: ("[^"]*") -> 匹配双引号括起来的任意内容 (保护区) + # Group 2: (\bkeyword\b) -> 匹配目标关键词,\b 确保词边界 (匹配区) + # re.IGNORECASE -> 忽略大小写 + pattern = r'("[^"]*")|\b(' + re.escape(situational_tag) + r")\b" + + def replace_callback(match): + nonlocal replaced_flag + + # 如果匹配到了 Group 1 (双引号内容),原样返回,不动它 + if match.group(1): + return match.group(1) + + # 如果匹配到了 Group 2 (关键词),进行替换 + if match.group(2): + original_word = match.group(2) + replaced_flag = True + return match_case(original_word, tgt_tag) + + # 进行替换 + text = re.sub(pattern, replace_callback, text, flags=re.IGNORECASE) + + return {"flag": replaced_flag, "text": text} + + +def perturb_one_attribution_in_text( + text: str, + gender: str, + speaking_rate: str, + pitch: str, + accent: str, + intrinsic_tags: List[str], + situational_tags: List[str], +) -> str: + perturbation_functions = [ + perturb_gender_in_text, + perturb_speaking_rate_in_text, + perturb_pitch_in_text, + perturb_accent_in_text, + perturb_intrinsic_tags, + perturb_situational_tags, + ] + attributions = [ + gender, + speaking_rate, + pitch, + accent, + intrinsic_tags, + situational_tags, + ] + candidates = list(zip(perturbation_functions, attributions)) + random.shuffle(candidates) + for func, attr in candidates: + result_dict = func(text, attr) + if result_dict["flag"]: + return result_dict["text"] + + if gender == "male": + result_dict = perturb_gender_in_text(text, "female") + elif gender == "female": + result_dict = perturb_gender_in_text(text, "male") + + if result_dict["flag"]: + return result_dict["text"] + + raise ValueError("No attribution found to perturb.") + + +if __name__ == "__main__": + import difflib + + from lhotse import load_manifest_lazy + + RED = "\033[31m" + GREEN = "\033[32m" + RESET = "\033[0m" + + def color_diff_ori(ori_text: str, norm_text: str) -> str: + sm = difflib.SequenceMatcher(a=ori_text, b=norm_text, autojunk=False) + out = [] + + for tag, i1, i2, j1, j2 in sm.get_opcodes(): + if tag == "equal": + out.append(ori_text[i1:i2]) + + elif tag == "delete": + out.append(f"{RED}{ori_text[i1:i2]}{RESET}") + + elif tag == "insert": + out.append(f"{GREEN}{norm_text[j1:j2]}{RESET}") + + elif tag == "replace": + if i1 != i2: + out.append(f"{RED}{ori_text[i1:i2]}{RESET}") + if j1 != j2: + out.append(f"{GREEN}{norm_text[j1:j2]}{RESET}") + + return "".join(out) + + train_cuts = load_manifest_lazy( + "data/manifests/paraspeechcaps_cuts_train_base_shuf-selected.jsonl.gz" + ) + cnt_short = 0 + cnt_long = 0 + cnt_short_total = 0 + cnt_long_total = 0 + for cut in train_cuts: + gender = cut.supervisions[0].gender + speaking_rate = cut.supervisions[0].custom["speaking_rate"] + pitch = cut.supervisions[0].custom["pitch"] + accent = cut.supervisions[0].custom["accent"] + intrinsic_tags = cut.supervisions[0].custom["intrinsic_tags"] + situational_tags = cut.supervisions[0].custom["situational_tags"] + + short_captions = cut.supervisions[0].custom["short_captions"] + long_captions = cut.supervisions[0].custom["long_captions"] + + for short_caption in short_captions: + text = perturb_one_attribution_in_text( + short_caption, + gender, + speaking_rate, + pitch, + accent, + intrinsic_tags, + situational_tags, + ) + # print(color_diff_ori(short_caption, text)) + + for long_caption in long_captions: + cnt_long_total += 1 + text = perturb_one_attribution_in_text( + long_caption, + gender, + speaking_rate, + pitch, + accent, + intrinsic_tags, + situational_tags, + ) + # print(color_diff_ori(long_caption, text)) diff --git a/egs/emilia/CLAP/spear_roberta/clap_module.py b/egs/emilia/CLAP/spear_roberta/clap_module.py new file mode 100644 index 0000000000..803c7f922b --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/clap_module.py @@ -0,0 +1,165 @@ +import torch +import torch.distributed.nn +from torch import distributed as dist +from torch import nn as nn +from torch.nn import functional as F + + +def gather_features( + audio_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, +): + # We gather tensors from all gpus + if gather_with_grad: + all_audio_features = torch.cat( + torch.distributed.nn.all_gather(audio_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + else: + gathered_audio_features = [ + torch.zeros_like(audio_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features, audio_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_audio_features, all_text_features + + +class ClipLoss(nn.Module): + def __init__( + self, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.rank = rank + self.world_size = world_size + + def forward( + self, + audio_features, + text_features, + logit_scale, + multi_positive=False, + ): + device = audio_features.device + + if self.world_size > 1: + all_audio_features, all_text_features = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + ) + + if self.local_loss: + logits_per_audio = logit_scale * audio_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_audio_features.T + else: + logits_per_audio = ( + logit_scale * all_audio_features @ all_text_features.T + ) + logits_per_text = logits_per_audio.T + else: + logits_per_audio = logit_scale * audio_features @ text_features.T + logits_per_text = logit_scale * text_features @ audio_features.T + + # calculated ground-truth + if multi_positive: + B_audio_local = audio_features.shape[0] + B_text_local = text_features.shape[0] + assert B_audio_local * 2 == B_text_local + B = B_audio_local + + if not self.local_loss: + num_audio_global = logits_per_audio.shape[0] + idx_audio = torch.arange(num_audio_global, device=device) + + rank_audio = idx_audio // B + local_audio = idx_audio % B + + pos1 = rank_audio * (2 * B) + local_audio + pos2 = pos1 + B + + num_text_global = logits_per_text.shape[0] + idx_text = torch.arange(num_text_global, device=device) + + rank_text = idx_text // (2 * B) + labels_text = rank_text * B + idx_text % B + else: + idx_local_audio = torch.arange(B, device=device) + pos1 = self.rank * (2 * B) + idx_local_audio + pos2 = pos1 + B + + idx_local_text = torch.arange(2 * B, device=device) + labels_text = self.rank * B + idx_local_text % B + + labels_audio = torch.zeros_like(logits_per_audio) + labels_audio.scatter_(1, pos1.unsqueeze(1), 0.5) + labels_audio.scatter_(1, pos2.unsqueeze(1), 0.5) + + total_loss = ( + F.cross_entropy(logits_per_audio, labels_audio) + + F.cross_entropy(logits_per_text, labels_text) + ) / 2 + + else: + num_logits = logits_per_audio.shape[0] + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return total_loss + + +def local_clip_loss( + audio_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, +) -> torch.Tensor: + B = audio_features.shape[0] + + assert text_features.shape[0] == B + assert text_features.shape[1] == 2 + + logits = logit_scale * (audio_features.unsqueeze(1) * text_features).sum(dim=-1) + + # logsumexp(pos) = log(e^P1) + log_sum_exp_pos = torch.logsumexp(logits[:, :1], dim=1) + + # logsumexp(all) = log(e^P1 + e^N1) + log_sum_exp_all = torch.logsumexp(logits, dim=1) + + # Loss = - log ( sum(exp(pos)) / sum(exp(all)) ) + # = - ( log(sum(exp(pos))) - log(sum(exp(all))) ) + # = log_sum_exp_all - log_sum_exp_pos + loss = log_sum_exp_all - log_sum_exp_pos + + return loss.mean() diff --git a/egs/emilia/CLAP/spear_roberta/encoder_interface.py b/egs/emilia/CLAP/spear_roberta/encoder_interface.py new file mode 120000 index 0000000000..c2eaca6712 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear_roberta/evaluate_retrieval.py b/egs/emilia/CLAP/spear_roberta/evaluate_retrieval.py new file mode 100755 index 0000000000..27de650e66 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/evaluate_retrieval.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import json +import logging +from pathlib import Path + +import torch +from asr_datamodule import DataModule +from finetune_stage2 import add_model_arguments, evaluate, get_model, get_params +from transformers import RobertaTokenizer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "speech-text-retrieval" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + # filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + # torch.save({"model": model.state_dict()}, filename) + # exit() + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + paraspeechcaps_test_cuts = datamodule.paraspeechcaps_test_cuts() + paraspeechcaps_test_dl = datamodule.test_dataloaders(paraspeechcaps_test_cuts) + + test_sets = [ + "paraspeechcaps_test", + ] + test_dls = [ + paraspeechcaps_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=test_dl, + caption_type="short_captions", + return_details=True, + ) + metrics = result_dict["metrics"] + details = result_dict["details"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + with open( + f"{params.res_dir}/details-decode-{params.suffix}", "w", encoding="utf-8" + ) as f: + json.dump(details, f, ensure_ascii=False, indent=2) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear_roberta/evaluate_zero_shot_classification.py b/egs/emilia/CLAP/spear_roberta/evaluate_zero_shot_classification.py new file mode 100755 index 0000000000..dfadd736b3 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/evaluate_zero_shot_classification.py @@ -0,0 +1,510 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from typing import Dict + +import torch +import torch.nn as nn +from asr_datamodule import DataModule +from finetune_stage2 import add_model_arguments, get_model, get_params +from transformers import RobertaTokenizer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +def map_iemocap_emotion_label_to_index(label: str) -> int: + label_map = { + "hap": 0, + "exc": 1, + "ang": 2, + "sad": 3, + "neu": 4, + } + return label_map[label] + + +def map_ravdess_emotion_label_to_index(label: str) -> int: + label_map = { + "angry": 0, + "calm": 1, + "disgust": 2, + "fearful": 3, + "happy": 4, + "sad": 5, + "surprised": 6, + "neutral": 7, + } + return label_map[label] + + +def map_ravdess_gender_label_to_index(label: str) -> int: + label_map = { + "male": 0, + "female": 1, + } + return label_map[label] + + +def map_cremad_emotion_label_to_index(label: str) -> int: + label_map = { + "H": 0, + "S": 1, + "A": 2, + "F": 3, + "D": 4, + "N": 5, + } + return label_map[label] + + +def map_cremad_age_label_to_index(label: str) -> int: + if label < 20: + index = 0 + elif label < 40: + index = 1 + elif label < 60: + index = 2 + else: + index = 3 + return index + + +def generate_iemocap_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a excited tone.", + "A speaker in a angry tone.", + "A speaker in a sad tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_emotion_prompts() -> str: + return [ + "A speaker in a angry tone.", + "A speaker in a calm tone.", + "A speaker in a disgust tone.", + "A speaker in a fear tone.", + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a surprised tone.", + "A speaker in a neutral tone.", + ] + + +def generate_ravdess_gender_prompts() -> str: + return [ + "A male speaker.", + "A female speaker.", + ] + + +def generate_cremad_emotion_prompts() -> str: + return [ + "A speaker in a happy tone.", + "A speaker in a sad tone.", + "A speaker in a angry tone.", + "A speaker in a fear tone.", + "A speaker in a disgust tone.", + "A speaker in a neutral tone.", + ] + + +def generate_cremad_age_prompts() -> str: + return [ + "A child or young teenager speaker.", + "An adult speaker.", + "A middle-aged speaker.", + "An older or elder speaker.", + ] + + +def evaluate( + params: AttributeDict, + model: nn.Module, + tokenizer: RobertaTokenizer, + test_set: str, + test_dl: torch.utils.data.DataLoader, +) -> Dict[str, float]: + """Run the Zero-Shot Classification validation process.""" + model.eval() + device = next(model.parameters()).device + + metrics = {} + eval_info = { + "all_audio_features": [], + "all_gt_labels": [], + } + + if test_set == "iemocap_emotion": + prompts = generate_iemocap_emotion_prompts() + elif test_set == "ravdess_emotion": + prompts = generate_ravdess_emotion_prompts() + elif test_set == "ravdess_gender": + prompts = generate_ravdess_gender_prompts() + elif test_set == "cremad_emotion": + prompts = generate_cremad_emotion_prompts() + elif test_set == "cremad_age": + prompts = generate_cremad_age_prompts() + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + text = tokenizer( + prompts, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + _, text_features, _ = model( + text=text, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + with torch.no_grad(): + for batch_idx, batch in enumerate(test_dl): + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + if test_set == "iemocap_emotion": + gt_labels = [ + map_iemocap_emotion_label_to_index( + c.supervisions[0].custom["emotion"] + ) + for c in batch["supervisions"]["cut"] + ] + elif test_set == "ravdess_emotion": + gt_labels = [ + map_ravdess_emotion_label_to_index( + c.supervisions[0].custom["emotion"] + ) + for c in batch["supervisions"]["cut"] + ] + elif test_set == "ravdess_gender": + gt_labels = [ + map_ravdess_gender_label_to_index(c.supervisions[0].gender) + for c in batch["supervisions"]["cut"] + ] + elif test_set == "cremad_emotion": + gt_labels = [ + map_cremad_emotion_label_to_index( + c.supervisions[0].custom["emotion"] + ) + for c in batch["supervisions"]["cut"] + ] + elif test_set == "cremad_age": + gt_labels = [ + map_cremad_age_label_to_index(c.supervisions[0].custom["age"]) + for c in batch["supervisions"]["cut"] + ] + else: + raise NotImplementedError(f"Unknown test set: {test_set}") + + audio_features, _, _ = model( + audio=feature, + audio_lens=feature_lens, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_gt_labels"].extend(gt_labels) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + metrics_single_dataset = compute_metrics( + audio_features=torch.cat(eval_info["all_audio_features"]), + text_features=text_features.cpu(), + gt_labels=torch.tensor(eval_info["all_gt_labels"], dtype=torch.int64), + test_set=test_set, + ) + metrics.update(metrics_single_dataset) + + result_dict = {"metrics": metrics} + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + gt_labels: torch.Tensor, + test_set: str, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + + logits_per_audio = torch.matmul(audio_features, text_features.t()) + preds = logits_per_audio.argmax(dim=1) + + if test_set == "iemocap_emotion": + gt_labels = gt_labels.clamp(min=1) + preds = preds.clamp(min=1) + + wa = (preds == gt_labels).float().mean().item() + + recall_sum = 0.0 + num_classes = 0 + for cls_idx in torch.unique(gt_labels): + cls_idx = cls_idx.item() + cls_mask = gt_labels == cls_idx + recall = (preds[cls_mask] == cls_idx).float().mean().item() + recall_sum += recall + num_classes += 1 + logging.info(f"{test_set}: cls {cls_idx}, recall {recall}") + uar = recall_sum / num_classes if num_classes > 0 else 0.0 + + return {"wa": wa, "uar": uar} + + +@torch.no_grad() +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "zero-shot-classification" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + # filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + # torch.save({"model": model.state_dict()}, filename) + # exit() + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + datamodule = DataModule(args) + + iemocap_test_cuts = datamodule.iemocap_cuts() + iemocap_test_dl = datamodule.test_dataloaders(iemocap_test_cuts) + + ravdess_test_cuts = datamodule.ravdess_cuts() + ravdess_test_dl = datamodule.test_dataloaders(ravdess_test_cuts) + + cremad_test_cuts = datamodule.cremad_cuts() + cremad_test_dl = datamodule.test_dataloaders(cremad_test_cuts) + + test_sets = [ + "iemocap_emotion", + "ravdess_emotion", + "cremad_emotion", + "ravdess_gender", + "cremad_age", + ] + test_dls = [ + iemocap_test_dl, + ravdess_test_dl, + cremad_test_dl, + ravdess_test_dl, + cremad_test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + result_dict = evaluate( + params=params, + model=model, + tokenizer=tokenizer, + test_set=test_set, + test_dl=test_dl, + ) + metrics = result_dict["metrics"] + logging.info( + f"{test_set}: " + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear_roberta/export_retrieval_ranks.py b/egs/emilia/CLAP/spear_roberta/export_retrieval_ranks.py new file mode 100755 index 0000000000..56ff3c313f --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/export_retrieval_ranks.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +from pathlib import Path + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="spear_roberta/exp", + help="The experiment dir", + ) + + return parser + + +def export_audio_to_text(details, output_dir): + """ + audio_to_text_ranks: + audio -> [text0, text1, ...] + """ + for idx, (audio_path, texts) in enumerate(details.items()): + item_dir = output_dir / str(idx) + item_dir.mkdir(parents=True, exist_ok=True) + + audio_path = Path(audio_path) + os.symlink(audio_path.resolve(), item_dir / audio_path.name) + + for rank, text in enumerate(texts): + with open(item_dir / f"{rank}.txt", "w", encoding="utf-8") as f: + f.write(text + "\n") + + +def export_text_to_audio(details, output_dir): + """ + text_to_audio_ranks: + text -> [audio0, audio1, ...] + """ + + for idx, (text, audio_paths) in enumerate(details.items()): + item_dir = output_dir / str(idx) + item_dir.mkdir(parents=True, exist_ok=True) + + with open(item_dir / "text.txt", "w", encoding="utf-8") as f: + f.write(text + "\n") + + for rank, audio_path in enumerate(audio_paths): + audio_path = audio_path.replace("GT# ", "") + audio_path = Path(audio_path) + os.symlink(audio_path.resolve(), item_dir / f"{rank}{audio_path.suffix}") + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + args.res_dir = args.exp_dir / "speech-text-retrieval" + + if args.iter > 0: + args.suffix = f"iter-{args.iter}-avg-{args.avg}" + else: + args.suffix = f"epoch-{args.epoch}-avg-{args.avg}" + + if args.use_averaged_model: + args.suffix += "-use-averaged-model" + + with open(f"{args.res_dir}/details-decode-{args.suffix}", encoding="utf-8") as f: + details = json.load(f) + + export_audio_to_text( + details["audio_to_text_ranks"], + args.res_dir / args.suffix / "audio_to_text_ranks", + ) + + export_text_to_audio( + details["text_to_audio_ranks"], + args.res_dir / args.suffix / "text_to_audio_ranks", + ) + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear_roberta/finetune_stage1.py b/egs/emilia/CLAP/spear_roberta/finetune_stage1.py new file mode 100644 index 0000000000..f8e88a0aad --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/finetune_stage1.py @@ -0,0 +1,1613 @@ +#!/usr/bin/env python3 +# Copyright 2021-2025 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import copy +import json +import logging +import math +import random +import warnings + +warnings.filterwarnings("ignore", category=FutureWarning) +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from asr_datamodule import DataModule +from clap_module import ClipLoss +from lhotse.utils import fix_random_seed +from model import CLAP +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from transformers import RobertaTokenizer +from zipformer2 import SimpleDownsample, Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_local_rank, + get_rank, + get_world_size, + setup_dist, +) +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + compare_model, + create_grad_scaler, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def unwrap_model(model: Union[nn.Module, DDP]) -> nn.Module: + if hasattr(model, "module"): + return model.module + else: + return model + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + batch_count = ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + if params.large_batch_count: + batch_count += 100000 + return batch_count + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + model = unwrap_model(model) + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used", + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true; -1 means not freezing", + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + # audio encoder + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--post-encoder-downsampling-factor", + type=int, + default=1, + help="The downsampling factor after the zipformer encoder", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # text encoder + parser.add_argument( + "--text-encoder-dim", + type=int, + default=768, + help="Embedding dimension in the text encoder model.", + ) + + # joiner + parser.add_argument( + "--joint-dim", + type=int, + default=512, + help="Dimension used in the joiner model.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=False, + help="""True if using multi-node multi-GPU. + You are not supposed to set it directly. + """, + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="The initial value of warmup, between 0 and 1", + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0, + help="The number of batches for lr warmup", + ) + + parser.add_argument( + "--large-batch-count", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=100000, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + parser.add_argument( + "--use-local-loss", + type=str2bool, + default=False, + help="""Whether to use local-only CLIP loss. If True, no cross-GPU + feature gather is performed, which saves communication and memory + but may reduce performance. + """, + ) + + parser.add_argument( + "--gather-with-grad", + type=str2bool, + default=False, + help="""Whether to allow gradients to flow through cross-GPU feature + gathering during Clip loss computation. If True, the gathered + global features retain gradient and participate in back-propagation, + providing a more complete optimization signal but increasing + communication cost and memory usage. If False, gathered features are + detached, reducing overhead but gradients only flow for local samples. + """, + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_train_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_criterion(params: AttributeDict, world_size: int, rank: int) -> nn.Module: + criterion = ClipLoss( + local_loss=params.use_local_loss, + gather_with_grad=params.gather_with_grad, + world_size=world_size, + rank=rank, + ) + return criterion + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if params.output_downsampling_factor == 2: + assert ( + params.post_encoder_downsampling_factor == 1 + ), "CANNOT perform double output downsample!" + + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_encoder_downsample_module(params: AttributeDict) -> nn.Module: + if params.post_encoder_downsampling_factor > 1: + downsample_module = SimpleDownsample( + max(_to_int_tuple(params.encoder_dim)), + downsample=params.post_encoder_downsampling_factor, + dropout=0.0, + ) + else: + downsample_module = None + return downsample_module + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + post_encoder_downsample = get_encoder_downsample_module(params) + + # modify the subsampling_factor accordingly + if params.output_downsampling_factor == 1: + params.subsampling_factor = 2 + + model = CLAP( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_downsample=post_encoder_downsample, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + text_encoder_dim=params.text_encoder_dim, + joint_dim=params.joint_dim, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "batch_idx_train", + "best_train_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[Any] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + criterion: nn.Module, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + text = tokenizer( + batch["supervisions"]["text"], + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + batch_idx_train = params.batch_idx_train + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info( + f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}." + ) + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + audio_features, text_features, logit_scale = model( + audio=feature, + audio_lens=feature_lens, + text=text, + freeze_audio_encoder=False, + freeze_text_encoder=False, + ) + loss = criterion( + audio_features=audio_features, + text_features=text_features, + logit_scale=logit_scale, + ) + + info = MetricsTracker() + batch_size = len(batch["supervisions"]["text"]) + info["utterances"] = batch_size + info["utt_clip_loss"] = loss.detach().cpu().item() * batch_size + + return loss, info + + +def evaluate( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + valid_dl: torch.utils.data.DataLoader, +) -> Dict[str, float]: + """Run the validation process.""" + model.eval() + + metrics = {} + num_samples = 0 + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "clip_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + captions = [ + c.supervisions[0].custom["long_captions"][1] + for c in batch["supervisions"]["cut"] + ] + text = tokenizer( + captions, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + audio_features, text_features, logit_scale = model( + audio=feature, + audio_lens=feature_lens, + text=text, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + num_samples += audio_features.shape[0] + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + metrics_single_dataset = compute_metrics( + audio_features=torch.cat(eval_info["all_audio_features"]), + text_features=torch.cat(eval_info["all_text_features"]), + logit_scale=logit_scale.cpu(), + ) + metrics.update(metrics_single_dataset) + + return metrics + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + logits_per_audio = logit_scale * audio_features @ text_features.t() + logits_per_text = logits_per_audio.t() + + labels = torch.arange(N, dtype=torch.long) + + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics = {} + metrics["clip_loss"] = total_loss.item() + metrics["num_samples"] = N + + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + # preds = torch.where(ranking == ground_truth)[1] + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + criterion: nn.Module, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + device = torch.device("cuda", torch.cuda.current_device()) + + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=0, + ) + + def slice_batch(batch, n): + if isinstance(batch, dict): + return {k: slice_batch(v, n) for k, v in batch.items()} + if isinstance(batch, tuple): + return tuple(slice_batch(v, n) for v in batch) + if isinstance(batch, list): + return [slice_batch(v, n) for v in batch[:n]] + if isinstance(batch, torch.Tensor): + if batch.dim() == 0: + return batch + return batch[:n] + return batch + + train_iter = iter(train_dl) + batch_idx = -1 + while True: + batch_idx += 1 + + try: + batch = next(train_iter) + batch_size = len(batch["supervisions"]["text"]) + except StopIteration: + batch_size = 0 + + if world_size > 1: + t = torch.tensor([batch_size], dtype=torch.int64, device=device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MIN) + min_batch_size = int(t.item()) + + if min_batch_size == 0: + batch_size = 0 + else: + if batch_size > min_batch_size: + batch = slice_batch(batch, min_batch_size) + batch_size = min_batch_size + + if batch_size == 0: + logging.info(f"Epoch {params.cur_epoch} finished.") + train_dl.sampler.cuts_iter.close() + break + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + criterion=criterion, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==mean and loss is computed over utterances + # in the batch. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + optimizer.zero_grad() + except Exception as e: # noqa + logging.warning(e) + save_bad_model() + display_and_save_batch(batch, params=params) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if ( + batch_idx % 25 == 0 + and cur_grad_scale < 2.0 + or batch_idx % 100 == 0 + and cur_grad_scale < 8.0 + or batch_idx % 400 == 0 + and cur_grad_scale < 32.0 + ): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + cur_batch_idx = batch_idx + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {cur_batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if ( + 0 + and batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): + if rank == 0: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Do validation on {valid_set}") + metrics = evaluate( + params=params, + model=unwrap_model(model), + tokenizer=tokenizer, + valid_dl=valid_dl, + ) + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"validation on {valid_set}, " + + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + for name, val in metrics.items(): + tb_writer.add_scalar( + f"valid/{valid_set}-{name}", val, params.batch_idx_train + ) + with open( + f"{params.exp_dir}/log/log-valid-{valid_set}.jsonl", "a+" + ) as f: + f.write(json.dumps(metrics) + "\n") + + loss_value = tot_loss["utt_clip_loss"] / tot_loss["utterances"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + if params.use_multi_node: + local_rank = get_local_rank() + else: + local_rank = rank + logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}") + + if world_size > 1: + setup_dist(rank, world_size, params.master_port, params.use_multi_node) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}") + + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + if params.start_epoch > 1: + # resuming training + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + # training from scratch + checkpoints = None + + # Setting the encoder lr scale + logging.info( + f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}" + ) + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + model.text_encoder.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + criterion = get_criterion(params, world_size, rank) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + warmup_batches=params.warmup_batches, + warmup_start=params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + datamodule = DataModule(args) + + train_cuts = datamodule.emilia_en_cuts() + + def remove_short_and_long_utt(c: Any): + # Keep only utterances with duration between 4 second and 30 seconds + if c.duration < 4.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if rank == 0: + duration_bins = datamodule.estimate_duration_bins( + cuts=train_cuts, + world_size=world_size, + rank=rank, + ) + datamodule.args.duration_bins = duration_bins + logging.info(f"Duration bins: {duration_bins}") + if world_size > 1: + obj_list = [duration_bins if rank == 0 else None] + torch.distributed.broadcast_object_list(obj_list, src=0) + datamodule.args.duration_bins = obj_list[0] + + last_upper = 30.0 + datamodule.args.max_seq_len_buckets = datamodule.args.duration_bins + [last_upper] + datamodule.args.fixed_batch_sizes = [ + max(1, int(params.max_duration // ub)) + for ub in datamodule.args.max_seq_len_buckets + ] + + # construct the training dataloader + train_dl = datamodule.train_dataloaders( + train_cuts, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + if 0 and rank == 0: + valid_cuts = datamodule.dev_clean_cuts() + valid_sets.append("librispeech") + valid_dls.append( + datamodule.valid_dataloaders( + valid_cuts, + world_size=1, + rank=rank, + ) + ) + + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + tokenizer=tokenizer, + optimizer=optimizer, + scheduler=scheduler, + criterion=criterion, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + features = batch["inputs"] + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + run(rank=rank, world_size=world_size, args=args) + else: + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear_roberta/finetune_stage1.sh b/egs/emilia/CLAP/spear_roberta/finetune_stage1.sh new file mode 100755 index 0000000000..b914863fb7 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/finetune_stage1.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash + +export PYTHONPATH=/root/icefall:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=$1 + +lr=0.045 + +# finetune checkpoint +do_finetune=1 +finetune_ckpt=download/iter-448000-avg-2.pt + +output_ds=2 +post_output_ds=1 + +freeze_encoder=0 +freeze_encoder_steps=-1 +encoder_lr_scale=0.02222 + +md=800 + +exp_dir=spear_roberta/exp_ft + +echo $exp_dir + +if false; then +python spear_roberta/finetune_stage1.py \ + --world-size 8 \ + --num-epochs 100 \ + --use-fp16 0 \ + --use-bf16 1 \ + --start-epoch 1 \ + --exp-dir $exp_dir \ + --manifest-dir data/manifests \ + --base-lr $lr \ + --do-finetune $do_finetune --init-modules "encoder_embed,encoder" --finetune-ckpt $finetune_ckpt \ + --freeze-encoder $freeze_encoder --freeze-encoder-steps $freeze_encoder_steps \ + --encoder-lr-scale $encoder_lr_scale \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --enable-musan 0 \ + --enable-spec-aug 0 \ + --max-duration $md +fi + +if true; then +epoch=$2 +# avg=$3 +for epoch in $(seq $epoch 1 $((epoch + 4))); do +for avg in $(seq 2 1 $((epoch - 1))); do + python spear_roberta/evaluate_retrieval.py \ + --epoch $epoch \ + --avg $avg \ + --manifest-dir data/manifests \ + --use-averaged-model 1 \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md + done +done +fi + +python /root/busygpu/run.py diff --git a/egs/emilia/CLAP/spear_roberta/finetune_stage2.py b/egs/emilia/CLAP/spear_roberta/finetune_stage2.py new file mode 100644 index 0000000000..d297dd1320 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/finetune_stage2.py @@ -0,0 +1,1698 @@ +#!/usr/bin/env python3 +# Copyright 2021-2025 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import copy +import json +import logging +import math +import random +import warnings + +warnings.filterwarnings("ignore", category=FutureWarning) +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from asr_datamodule import DataModule +from clap_module import ClipLoss +from lhotse.utils import fix_random_seed +from model import CLAP +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from transformers import RobertaTokenizer +from zipformer2 import SimpleDownsample, Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_local_rank, + get_rank, + get_world_size, + setup_dist, +) +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + compare_model, + create_grad_scaler, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def unwrap_model(model: Union[nn.Module, DDP]) -> nn.Module: + if hasattr(model, "module"): + return model.module + else: + return model + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + batch_count = ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + if params.large_batch_count: + batch_count += 100000 + return batch_count + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + model = unwrap_model(model) + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used", + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true; -1 means not freezing", + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + # audio encoder + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--post-encoder-downsampling-factor", + type=int, + default=1, + help="The downsampling factor after the zipformer encoder", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # text encoder + parser.add_argument( + "--text-encoder-dim", + type=int, + default=768, + help="Embedding dimension in the text encoder model.", + ) + + # joiner + parser.add_argument( + "--joint-dim", + type=int, + default=512, + help="Dimension used in the joiner model.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=False, + help="""True if using multi-node multi-GPU. + You are not supposed to set it directly. + """, + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="The initial value of warmup, between 0 and 1", + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0, + help="The number of batches for lr warmup", + ) + + parser.add_argument( + "--large-batch-count", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=100000, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + parser.add_argument( + "--use-local-loss", + type=str2bool, + default=False, + help="""Whether to use local-only CLIP loss. If True, no cross-GPU + feature gather is performed, which saves communication and memory + but may reduce performance. + """, + ) + + parser.add_argument( + "--gather-with-grad", + type=str2bool, + default=False, + help="""Whether to allow gradients to flow through cross-GPU feature + gathering during Clip loss computation. If True, the gathered + global features retain gradient and participate in back-propagation, + providing a more complete optimization signal but increasing + communication cost and memory usage. If False, gathered features are + detached, reducing overhead but gradients only flow for local samples. + """, + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_train_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_criterion(params: AttributeDict, world_size: int, rank: int) -> nn.Module: + criterion = ClipLoss( + local_loss=params.use_local_loss, + gather_with_grad=params.gather_with_grad, + world_size=world_size, + rank=rank, + ) + return criterion + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if params.output_downsampling_factor == 2: + assert ( + params.post_encoder_downsampling_factor == 1 + ), "CANNOT perform double output downsample!" + + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_encoder_downsample_module(params: AttributeDict) -> nn.Module: + if params.post_encoder_downsampling_factor > 1: + downsample_module = SimpleDownsample( + max(_to_int_tuple(params.encoder_dim)), + downsample=params.post_encoder_downsampling_factor, + dropout=0.0, + ) + else: + downsample_module = None + return downsample_module + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + post_encoder_downsample = get_encoder_downsample_module(params) + + # modify the subsampling_factor accordingly + if params.output_downsampling_factor == 1: + params.subsampling_factor = 2 + + model = CLAP( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_downsample=post_encoder_downsample, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + text_encoder_dim=params.text_encoder_dim, + joint_dim=params.joint_dim, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "batch_idx_train", + "best_train_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[Any] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + criterion: nn.Module, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + batch_idx_train = params.batch_idx_train + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + def get_p(batch_idx_train): + return min(0.05 + (0.50 - 0.05) * batch_idx_train / 10000, 0.50) + + task_list = ["couple_of_long", "short_and_long"] + current_prob = get_p(batch_idx_train) + task_probs = [current_prob, 1 - current_prob] + current_task = random.choices(task_list, weights=task_probs, k=1)[0] + + if current_task == "short_and_long": + short_captions = [ + random.choice(c.supervisions[0].custom["short_captions"]) + for c in batch["supervisions"]["cut"] + ] + long_captions = [ + random.choice(c.supervisions[0].custom["long_captions"]) + for c in batch["supervisions"]["cut"] + ] + captions = short_captions + long_captions + elif current_task == "couple_of_long": + sampled_pairs = [ + random.sample(c.supervisions[0].custom["long_captions"], 2) + for c in batch["supervisions"]["cut"] + ] + long_captions1 = [pair[0] for pair in sampled_pairs] + long_captions2 = [pair[1] for pair in sampled_pairs] + captions = long_captions1 + long_captions2 + + text = tokenizer( + captions, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info( + f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}." + ) + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + audio_features, text_features, logit_scale = model( + audio=feature, + audio_lens=feature_lens, + text=text, + freeze_audio_encoder=False, + freeze_text_encoder=False, + ) + loss = criterion( + audio_features=audio_features, + text_features=text_features, + logit_scale=logit_scale, + multi_positive=True, + ) + + info = MetricsTracker() + batch_size = len(batch["supervisions"]["cut"]) + info["utterances"] = batch_size + info["utt_clip_loss"] = loss.detach().cpu().item() * batch_size + + return loss, info + + +def evaluate( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + valid_dl: torch.utils.data.DataLoader, + caption_type: str, + return_details: bool = False, +) -> Dict[str, float]: + """Run the validation process.""" + model.eval() + + metrics = {} + num_samples = 0 + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "clip_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + eval_detail = { + "all_audio_paths": [], + "all_texts": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + if caption_type == "short_captions": + captions = [ + c.supervisions[0].custom[caption_type][0] + for c in batch["supervisions"]["cut"] + ] + elif caption_type == "long_captions": + captions = [ + c.supervisions[0].custom[caption_type][-1] + for c in batch["supervisions"]["cut"] + ] + else: + raise ValueError + + text = tokenizer( + captions, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + audio_features, text_features, logit_scale = model( + audio=feature, + audio_lens=feature_lens, + text=text, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + num_samples += audio_features.shape[0] + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + if return_details: + eval_detail["all_audio_paths"].extend( + [ + c.recording.sources[0].source + for c in batch["supervisions"]["cut"] + ] + ) + eval_detail["all_texts"].extend(captions) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + metrics_single_dataset, details_single_dataset = compute_metrics( + audio_features=torch.cat(eval_info["all_audio_features"]), + text_features=torch.cat(eval_info["all_text_features"]), + logit_scale=logit_scale.cpu(), + ) + metrics.update(metrics_single_dataset) + + if return_details: + details = {} + for k, ranks in details_single_dataset.items(): + if k == "audio_to_text_ranks": + src_list = eval_detail["all_audio_paths"] + tgt_list = eval_detail["all_texts"] + elif k == "text_to_audio_ranks": + src_list = eval_detail["all_texts"] + tgt_list = eval_detail["all_audio_paths"] + else: + raise ValueError + + details[k] = { + src_list[i]: [ + f"GT# {tgt_list[j]}" if j == i else tgt_list[j] for j in ranking + ] + for i, ranking in enumerate(ranks) + } + + result_dict = {"metrics": metrics} + if return_details: + result_dict["details"] = details + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + logits_per_audio = logit_scale * audio_features @ text_features.t() + logits_per_text = logits_per_audio.t() + + labels = torch.arange(N, dtype=torch.long) + + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics = {} + metrics["clip_loss"] = total_loss.item() + metrics["num_samples"] = N + + details = {} + + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + # preds = torch.where(ranking == ground_truth)[1] + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + details[f"{name}_ranks"] = ranking.detach().cpu().tolist() + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics, details + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + criterion: nn.Module, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + device = torch.device("cuda", torch.cuda.current_device()) + + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=0, + ) + + def slice_batch(batch, n): + if isinstance(batch, dict): + return {k: slice_batch(v, n) for k, v in batch.items()} + if isinstance(batch, tuple): + return tuple(slice_batch(v, n) for v in batch) + if isinstance(batch, list): + return [slice_batch(v, n) for v in batch[:n]] + if isinstance(batch, torch.Tensor): + if batch.dim() == 0: + return batch + return batch[:n] + return batch + + train_iter = iter(train_dl) + batch_idx = -1 + while True: + batch_idx += 1 + + try: + batch = next(train_iter) + batch_size = len(batch["supervisions"]["cut"]) + except StopIteration: + batch_size = 0 + + if world_size > 1: + t = torch.tensor([batch_size], dtype=torch.int64, device=device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MIN) + min_batch_size = int(t.item()) + + if min_batch_size == 0: + batch_size = 0 + else: + if batch_size > min_batch_size: + batch = slice_batch(batch, min_batch_size) + batch_size = min_batch_size + + if batch_size == 0: + logging.info(f"Epoch {params.cur_epoch} finished.") + train_dl.sampler.cuts_iter.close() + break + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + criterion=criterion, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==mean and loss is computed over utterances + # in the batch. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + optimizer.zero_grad() + except Exception as e: # noqa + logging.warning(e) + save_bad_model() + display_and_save_batch(batch, params=params) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if ( + batch_idx % 25 == 0 + and cur_grad_scale < 2.0 + or batch_idx % 100 == 0 + and cur_grad_scale < 8.0 + or batch_idx % 400 == 0 + and cur_grad_scale < 32.0 + ): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + cur_batch_idx = batch_idx + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {cur_batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if rank == 0: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Do validation on {valid_set}") + for caption_type in ["short_captions", "long_captions"]: + metrics = evaluate( + params=params, + model=unwrap_model(model), + tokenizer=tokenizer, + valid_dl=valid_dl, + caption_type=caption_type, + return_details=False, + )["metrics"] + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"validation on {valid_set}, " + f"{caption_type}, " + + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + for name, val in metrics.items(): + tb_writer.add_scalar( + f"valid/{valid_set}-{caption_type}-{name}", + val, + params.batch_idx_train, + ) + with open( + f"{params.exp_dir}/log/log-valid-{valid_set}-{caption_type}.jsonl", + "a+", + ) as f: + f.write(json.dumps(metrics) + "\n") + + loss_value = tot_loss["utt_clip_loss"] / tot_loss["utterances"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + if params.use_multi_node: + local_rank = get_local_rank() + else: + local_rank = rank + logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}") + + if world_size > 1: + setup_dist(rank, world_size, params.master_port, params.use_multi_node) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}") + + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + if params.start_epoch > 1: + # resuming training + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + # training from scratch + checkpoints = None + + # Setting the encoder lr scale + logging.info( + f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}" + ) + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + model.text_encoder.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + criterion = get_criterion(params, world_size, rank) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + warmup_batches=params.warmup_batches, + warmup_start=params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + datamodule = DataModule(args) + + train_cuts = datamodule.paraspeechcaps_train_base_cuts() + + def remove_short_and_long_utt(c: Any): + # Keep only utterances with duration between 2 second and 30 seconds + if c.duration < 2.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if rank == 0: + duration_bins = datamodule.estimate_duration_bins( + cuts=train_cuts, + world_size=world_size, + rank=rank, + ) + datamodule.args.duration_bins = duration_bins + logging.info(f"Duration bins: {duration_bins}") + if world_size > 1: + obj_list = [duration_bins if rank == 0 else None] + torch.distributed.broadcast_object_list(obj_list, src=0) + datamodule.args.duration_bins = obj_list[0] + + last_upper = 30.0 + datamodule.args.max_seq_len_buckets = datamodule.args.duration_bins + [last_upper] + datamodule.args.fixed_batch_sizes = [ + max(1, int(params.max_duration // ub)) + for ub in datamodule.args.max_seq_len_buckets + ] + + # construct the training dataloader + train_dl = datamodule.train_dataloaders( + train_cuts, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + if rank == 0: + valid_cuts = datamodule.paraspeechcaps_test_cuts() + valid_sets.append("paraspeechcaps test") + valid_dls.append( + datamodule.valid_dataloaders( + valid_cuts, + world_size=1, + rank=rank, + ) + ) + + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + tokenizer=tokenizer, + optimizer=optimizer, + scheduler=scheduler, + criterion=criterion, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + features = batch["inputs"] + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + run(rank=rank, world_size=world_size, args=args) + else: + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear_roberta/finetune_stage2.sh b/egs/emilia/CLAP/spear_roberta/finetune_stage2.sh new file mode 100755 index 0000000000..7912376f2b --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/finetune_stage2.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash + +export PYTHONPATH=/root/icefall:$PYTHONPATH +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export CUDA_VISIBLE_DEVICES=$1 + +lr=0.001 + +# finetune checkpoint +do_finetune=1 +finetune_ckpt=download/stage1-epoch-45-avg-28.pt + +output_ds=2 +post_output_ds=1 + +freeze_encoder=0 +freeze_encoder_steps=-1 +encoder_lr_scale=1 + +md=800 + +exp_dir=spear_roberta/exp_ft + +echo $exp_dir + +if false; then +python spear_roberta/finetune_stage2.py \ + --world-size 8 \ + --num-epochs 400 \ + --use-fp16 0 \ + --use-bf16 1 \ + --start-epoch 1 \ + --exp-dir $exp_dir \ + --manifest-dir data/manifests \ + --base-lr $lr \ + --do-finetune $do_finetune --finetune-ckpt $finetune_ckpt \ + --freeze-encoder $freeze_encoder --freeze-encoder-steps $freeze_encoder_steps \ + --encoder-lr-scale $encoder_lr_scale \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --enable-musan 0 \ + --enable-spec-aug 0 \ + --max-duration $md +fi + +if false; then +epoch=$2 +# avg=$3 +for epoch in $(seq $epoch 5 $((epoch + 24))); do +for avg in $(seq 2 5 $((epoch - 1))); do + python spear_roberta/evaluate_retrieval.py \ + --epoch $epoch \ + --avg $avg \ + --manifest-dir data/manifests \ + --use-averaged-model 1 \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +done +done +fi + +if true; then +epoch=$2 +avg=$3 +# while read -r score tag; do + # epoch=$(echo "$tag" | awk -F'[-]' '{print $2}') + # avg=$(echo "$tag" | awk -F'[-]' '{print $4}') + python spear_roberta/evaluate_zero_shot_classification.py \ + --epoch $epoch \ + --avg $avg \ + --manifest-dir data/manifests \ + --use-averaged-model 1 \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md +# done < "$2" +fi + +# for i in {0..7}; do CUDA_VISIBLE_DEVICES=$i python /root/busygpu/run.py & done +# python /root/busygpu/run.py & diff --git a/egs/emilia/CLAP/spear_roberta/finetune_stage3.py b/egs/emilia/CLAP/spear_roberta/finetune_stage3.py new file mode 100644 index 0000000000..1d11350a8f --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/finetune_stage3.py @@ -0,0 +1,1740 @@ +#!/usr/bin/env python3 +# Copyright 2021-2025 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import copy +import json +import logging +import math +import random +import warnings + +warnings.filterwarnings("ignore", category=FutureWarning) +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from asr_datamodule import DataModule +from attribute_perturbation import perturb_one_attribution_in_text +from clap_module import ClipLoss, local_clip_loss +from lhotse.utils import fix_random_seed +from model import CLAP +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from transformers import RobertaTokenizer +from zipformer2 import SimpleDownsample, Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import ( + cleanup_dist, + get_local_rank, + get_rank, + get_world_size, + setup_dist, +) +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + compare_model, + create_grad_scaler, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def unwrap_model(model: Union[nn.Module, DDP]) -> nn.Module: + if hasattr(model, "module"): + return model.module + else: + return model + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + batch_count = ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + if params.large_batch_count: + batch_count += 100000 + return batch_count + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + model = unwrap_model(model) + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + parser.add_argument( + "--freeze-encoder", + type=str2bool, + default=False, + help="Freeze the encoder of the model. If true, freeze-encoder-steps won't be used", + ) + + parser.add_argument( + "--freeze-encoder-steps", + type=int, + default=-1, + help="For this number of steps, freeze the encoder. If set, freeze-encoder cannot be true; -1 means not freezing", + ) + + parser.add_argument( + "--encoder-lr-scale", + type=float, + default=1.0, + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + # audio encoder + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--output-downsampling-factor", + type=int, + default=2, + help="The outout downsampling factor. Default is 2. If 1, no downsample is performed.", + ) + + parser.add_argument( + "--post-encoder-downsampling-factor", + type=int, + default=1, + help="The downsampling factor after the zipformer encoder", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + # text encoder + parser.add_argument( + "--text-encoder-dim", + type=int, + default=768, + help="Embedding dimension in the text encoder model.", + ) + + # joiner + parser.add_argument( + "--joint-dim", + type=int, + default=512, + help="Dimension used in the joiner model.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=False, + help="""True if using multi-node multi-GPU. + You are not supposed to set it directly. + """, + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="The initial value of warmup, between 0 and 1", + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=500.0, + help="The number of batches for lr warmup", + ) + + parser.add_argument( + "--large-batch-count", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=100000, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + parser.add_argument( + "--use-local-loss", + type=str2bool, + default=False, + help="""Whether to use local-only CLIP loss. If True, no cross-GPU + feature gather is performed, which saves communication and memory + but may reduce performance. + """, + ) + + parser.add_argument( + "--gather-with-grad", + type=str2bool, + default=False, + help="""Whether to allow gradients to flow through cross-GPU feature + gathering during Clip loss computation. If True, the gathered + global features retain gradient and participate in back-propagation, + providing a more complete optimization signal but increasing + communication cost and memory usage. If False, gathered features are + detached, reducing overhead but gradients only flow for local samples. + """, + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_train_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 128, + "subsampling_factor": 4, # not passed in, this is fixed. + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_criterion(params: AttributeDict, world_size: int, rank: int) -> nn.Module: + criterion = ClipLoss( + local_loss=params.use_local_loss, + gather_with_grad=params.gather_with_grad, + world_size=world_size, + rank=rank, + ) + return criterion + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if params.output_downsampling_factor == 2: + assert ( + params.post_encoder_downsampling_factor == 1 + ), "CANNOT perform double output downsample!" + + encoder = Zipformer2( + output_downsampling_factor=params.output_downsampling_factor, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_encoder_downsample_module(params: AttributeDict) -> nn.Module: + if params.post_encoder_downsampling_factor > 1: + downsample_module = SimpleDownsample( + max(_to_int_tuple(params.encoder_dim)), + downsample=params.post_encoder_downsampling_factor, + dropout=0.0, + ) + else: + downsample_module = None + return downsample_module + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + post_encoder_downsample = get_encoder_downsample_module(params) + + # modify the subsampling_factor accordingly + if params.output_downsampling_factor == 1: + params.subsampling_factor = 2 + + model = CLAP( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_downsample=post_encoder_downsample, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + text_encoder_dim=params.text_encoder_dim, + joint_dim=params.joint_dim, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "batch_idx_train", + "best_train_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + assert dst_state_dict[key].shape == src_state_dict[key].shape + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[Any] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + criterion: nn.Module, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + batch_size = len(batch["supervisions"]["cut"]) + + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + task_list = ["short_and_long", "couple_of_long"] + task_probs = [1.00, 0.00] + current_task = random.choices(task_list, weights=task_probs, k=1)[0] + + if current_task == "short_and_long": + short_captions = [] + long_captions = [] + perturbed_short_captions = [] + for c in batch["supervisions"]["cut"]: + gender = c.supervisions[0].gender + speaking_rate = c.supervisions[0].custom["speaking_rate"] + pitch = c.supervisions[0].custom["pitch"] + accent = c.supervisions[0].custom["accent"] + intrinsic_tags = c.supervisions[0].custom["intrinsic_tags"] + situational_tags = c.supervisions[0].custom["situational_tags"] + short_caption = random.choice(c.supervisions[0].custom["short_captions"]) + long_caption = random.choice(c.supervisions[0].custom["long_captions"]) + perturbed_short_caption = perturb_one_attribution_in_text( + short_caption, + gender=gender, + speaking_rate=speaking_rate, + pitch=pitch, + accent=accent, + intrinsic_tags=intrinsic_tags, + situational_tags=situational_tags, + ) + short_captions.append(short_caption) + long_captions.append(long_caption) + perturbed_short_captions.append(perturbed_short_caption) + captions = short_captions + long_captions + perturbed_short_captions + elif current_task == "couple_of_long": + long_captions1 = [] + long_captions2 = [] + for c in batch["supervisions"]["cut"]: + long_caption1, long_caption2 = random.sample( + c.supervisions[0].custom["long_captions"], 2 + ) + long_captions1.append(long_caption1) + long_captions2.append(long_caption2) + captions = long_captions1 + long_captions2 + else: + raise ValueError(f"Unsupported task: {current_task}") + + text = tokenizer( + captions, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + batch_idx_train = params.batch_idx_train + + if params.freeze_encoder_steps > 0: + freeze_encoder = batch_idx_train < params.freeze_encoder_steps + if random.random() < 0.01 and is_training: + logging.info(f"Step: {batch_idx_train}. Freeze encoder: {freeze_encoder}") + if batch_idx_train == params.freeze_encoder_steps: + logging.info( + f"Reaching {params.freeze_encoder_steps}. Freeze encoder: {freeze_encoder}." + ) + else: + freeze_encoder = params.freeze_encoder + + with torch.set_grad_enabled(is_training): + audio_features, text_features, logit_scale = model( + audio=feature, + audio_lens=feature_lens, + text=text, + freeze_audio_encoder=False, + freeze_text_encoder=False, + ) + global_loss = criterion( + audio_features=audio_features, + text_features=text_features[: batch_size * 2], + logit_scale=logit_scale, + multi_positive=True, + ) + + if current_task == "short_and_long": + text_features_pos1 = text_features[0:batch_size] + text_features_neg1 = text_features[2 * batch_size : 3 * batch_size] + + local_loss = local_clip_loss( + audio_features=audio_features, + text_features=torch.stack( + [ + text_features_pos1, + text_features_neg1, + ], + dim=1, + ), # (B, 3, D) + logit_scale=logit_scale, + ) + else: + local_loss = None + + loss = global_loss + local_loss * 0.2 + + info = MetricsTracker() + info["utterances"] = batch_size + info["utt_clip_loss"] = loss.detach().cpu().item() * batch_size + info["utt_global_loss"] = global_loss.detach().cpu().item() * batch_size + info["utt_local_loss"] = ( + local_loss.detach().cpu().item() * batch_size if local_loss else None + ) + + return loss, info + + +def evaluate( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + valid_dl: torch.utils.data.DataLoader, + caption_type: str, + return_details: bool = False, +) -> Dict[str, float]: + """Run the validation process.""" + model.eval() + + metrics = {} + num_samples = 0 + # Note: this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = { + "clip_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + eval_detail = { + "all_audio_paths": [], + "all_texts": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + if caption_type == "short_captions": + captions = [ + c.supervisions[0].custom[caption_type][0] + for c in batch["supervisions"]["cut"] + ] + elif caption_type == "long_captions": + captions = [ + c.supervisions[0].custom[caption_type][-1] + for c in batch["supervisions"]["cut"] + ] + else: + raise ValueError + + text = tokenizer( + captions, + padding=True, + truncation=True, + return_tensors="pt", + ) + text = {k: v.to(device) for k, v in text.items()} + + audio_features, text_features, logit_scale = model( + audio=feature, + audio_lens=feature_lens, + text=text, + freeze_audio_encoder=True, + freeze_text_encoder=True, + ) + + num_samples += audio_features.shape[0] + + eval_info["all_audio_features"].append(audio_features.cpu()) + eval_info["all_text_features"].append(text_features.cpu()) + + if return_details: + eval_detail["all_audio_paths"].extend( + [ + c.recording.sources[0].source + for c in batch["supervisions"]["cut"] + ] + ) + eval_detail["all_texts"].extend(captions) + + if batch_idx % 100 == 0: + logging.info(f"Validation batch {batch_idx}") + + metrics_single_dataset, details_single_dataset = compute_metrics( + audio_features=torch.cat(eval_info["all_audio_features"]), + text_features=torch.cat(eval_info["all_text_features"]), + logit_scale=logit_scale.cpu(), + ) + metrics.update(metrics_single_dataset) + + if return_details: + details = {} + for k, ranks in details_single_dataset.items(): + if k == "audio_to_text_ranks": + src_list = eval_detail["all_audio_paths"] + tgt_list = eval_detail["all_texts"] + elif k == "text_to_audio_ranks": + src_list = eval_detail["all_texts"] + tgt_list = eval_detail["all_audio_paths"] + else: + raise ValueError + + details[k] = { + src_list[i]: [ + f"GT# {tgt_list[j]}" if j == i else tgt_list[j] for j in ranking + ] + for i, ranking in enumerate(ranks) + } + + result_dict = {"metrics": metrics} + if return_details: + result_dict["details"] = details + + return result_dict + + +@torch.no_grad() +def compute_metrics( + audio_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, +) -> Dict[str, float]: + assert audio_features.dim() == 2 and text_features.dim() == 2, "Shapes must match" + assert audio_features.shape[0] == text_features.shape[0], "Batch sizes must match" + assert audio_features.shape[1] == text_features.shape[1], "Feature dims must match" + + N = audio_features.shape[0] + + logits_per_audio = logit_scale * audio_features @ text_features.t() + logits_per_text = logits_per_audio.t() + + labels = torch.arange(N, dtype=torch.long) + + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics = {} + metrics["clip_loss"] = total_loss.item() + metrics["num_samples"] = N + + details = {} + + for name, logit in { + "audio_to_text": logits_per_audio, + "text_to_audio": logits_per_text, + }.items(): + ranking = torch.argsort(logit, dim=1, descending=True) + + # preds = torch.where(ranking == ground_truth)[1] + ranks = torch.empty_like(ranking) + ranks.scatter_(1, ranking, torch.arange(N).unsqueeze(0).expand(N, -1)) + idx = torch.arange(N) + preds = ranks[idx, idx] + + details[f"{name}_ranks"] = ranking.detach().cpu().tolist() + + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = (preds < k).float().mean().item() + + metrics[f"{name}_mAP@10"] = ( + torch.where( + preds < 10, + 1.0 / (preds.float() + 1.0), + torch.zeros_like(preds, dtype=torch.float), + ) + .mean() + .item() + ) + + return metrics, details + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: RobertaTokenizer, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + criterion: nn.Module, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + device = torch.device("cuda", torch.cuda.current_device()) + + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=0, + ) + + def slice_batch(batch, n): + if isinstance(batch, dict): + return {k: slice_batch(v, n) for k, v in batch.items()} + if isinstance(batch, tuple): + return tuple(slice_batch(v, n) for v in batch) + if isinstance(batch, list): + return [slice_batch(v, n) for v in batch[:n]] + if isinstance(batch, torch.Tensor): + if batch.dim() == 0: + return batch + return batch[:n] + return batch + + train_iter = iter(train_dl) + batch_idx = -1 + while True: + batch_idx += 1 + + try: + batch = next(train_iter) + batch_size = len(batch["supervisions"]["cut"]) + except StopIteration: + batch_size = 0 + + if world_size > 1: + t = torch.tensor([batch_size], dtype=torch.int64, device=device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MIN) + min_batch_size = int(t.item()) + + if min_batch_size == 0: + batch_size = 0 + else: + if batch_size > min_batch_size: + batch = slice_batch(batch, min_batch_size) + batch_size = min_batch_size + + if batch_size == 0: + logging.info(f"Epoch {params.cur_epoch} finished.") + train_dl.sampler.cuts_iter.close() + break + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + criterion=criterion, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==mean and loss is computed over utterances + # in the batch. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + optimizer.zero_grad() + except Exception as e: # noqa + logging.warning(e) + save_bad_model() + display_and_save_batch(batch, params=params) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if ( + batch_idx % 25 == 0 + and cur_grad_scale < 2.0 + or batch_idx % 100 == 0 + and cur_grad_scale < 8.0 + or batch_idx % 400 == 0 + and cur_grad_scale < 32.0 + ): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + cur_batch_idx = batch_idx + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {cur_batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if rank == 0: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Do validation on {valid_set}") + for caption_type in ["short_captions", "long_captions"]: + metrics = evaluate( + params=params, + model=unwrap_model(model), + tokenizer=tokenizer, + valid_dl=valid_dl, + caption_type=caption_type, + return_details=False, + )["metrics"] + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"validation on {valid_set}, " + f"{caption_type}, " + + " ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + for name, val in metrics.items(): + tb_writer.add_scalar( + f"valid/{valid_set}-{caption_type}-{name}", + val, + params.batch_idx_train, + ) + with open( + f"{params.exp_dir}/log/log-valid-{valid_set}-{caption_type}.jsonl", + "a+", + ) as f: + f.write(json.dumps(metrics) + "\n") + + loss_value = tot_loss["utt_clip_loss"] / tot_loss["utterances"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + if params.use_multi_node: + local_rank = get_local_rank() + else: + local_rank = rank + logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}") + + if world_size > 1: + setup_dist(rank, world_size, params.master_port, params.use_multi_node) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}") + + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + compare_model(model.state_dict(), model_avg.state_dict()) + model_avg = copy.deepcopy(model).to(torch.float64) + else: + if params.start_epoch > 1: + # resuming training + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + # training from scratch + checkpoints = None + + # Setting the encoder lr scale + logging.info( + f"Setting the lr scale of parameters in encoder and encoder_embed to {params.encoder_lr_scale}" + ) + if params.encoder_lr_scale != 1.0: + model.encoder.lr_scale = params.encoder_lr_scale + model.encoder_embed.lr_scale = params.encoder_lr_scale + model.text_encoder.lr_scale = params.encoder_lr_scale + + # Check the freezing encoder configuration + if params.freeze_encoder_steps > 0: + logging.info(f"Freeze the encoder for {params.freeze_encoder_steps} steps") + assert not params.freeze_encoder + if params.freeze_encoder: + logging.info(f"Freeze the encoder for the whole training") + assert params.freeze_encoder_steps < 0 + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + criterion = get_criterion(params, world_size, rank) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + warmup_batches=params.warmup_batches, + warmup_start=params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + datamodule = DataModule(args) + + train_cuts = datamodule.paraspeechcaps_train_base_cuts() + + def remove_short_and_long_utt(c: Any): + # Keep only utterances with duration between 2 second and 30 seconds + if c.duration < 2.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if rank == 0: + duration_bins = datamodule.estimate_duration_bins( + cuts=train_cuts, + world_size=world_size, + rank=rank, + ) + datamodule.args.duration_bins = duration_bins + logging.info(f"Duration bins: {duration_bins}") + if world_size > 1: + obj_list = [duration_bins if rank == 0 else None] + torch.distributed.broadcast_object_list(obj_list, src=0) + datamodule.args.duration_bins = obj_list[0] + + last_upper = 30.0 + datamodule.args.max_seq_len_buckets = datamodule.args.duration_bins + [last_upper] + datamodule.args.fixed_batch_sizes = [ + max(1, int(params.max_duration // ub)) + for ub in datamodule.args.max_seq_len_buckets + ] + + # construct the training dataloader + train_dl = datamodule.train_dataloaders( + train_cuts, + world_size=world_size, + rank=rank, + ) + + valid_sets = [] + valid_dls = [] + if rank == 0: + valid_cuts = datamodule.paraspeechcaps_test_cuts() + valid_sets.append("paraspeechcaps test") + valid_dls.append( + datamodule.valid_dataloaders( + valid_cuts, + world_size=1, + rank=rank, + ) + ) + + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + tokenizer=tokenizer, + optimizer=optimizer, + scheduler=scheduler, + criterion=criterion, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + features = batch["inputs"] + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + DataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + if args.use_multi_node: + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + run(rank=rank, world_size=world_size, args=args) + else: + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/emilia/CLAP/spear_roberta/finetune_stage3.sh b/egs/emilia/CLAP/spear_roberta/finetune_stage3.sh new file mode 100755 index 0000000000..7bcbc1b52f --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/finetune_stage3.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +export PYTHONPATH=/root/icefall:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# export CUDA_VISIBLE_DEVICES=$1 + +lr=0.001 + +# finetune checkpoint +do_finetune=1 +finetune_ckpt=download/stage1-epoch-45-avg-28.pt + +output_ds=2 +post_output_ds=1 + +freeze_encoder=0 +freeze_encoder_steps=-1 +encoder_lr_scale=1 + +md=750 + +exp_dir=spear_roberta/exp_ft + +echo $exp_dir + +if true; then +python spear_roberta/finetune_stage3.py \ + --world-size 8 \ + --num-epochs 400 \ + --use-fp16 0 \ + --use-bf16 1 \ + --start-epoch 1 \ + --exp-dir $exp_dir \ + --manifest-dir data/manifests \ + --base-lr $lr \ + --do-finetune $do_finetune --finetune-ckpt $finetune_ckpt \ + --freeze-encoder $freeze_encoder --freeze-encoder-steps $freeze_encoder_steps \ + --encoder-lr-scale $encoder_lr_scale \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --enable-musan 0 \ + --enable-spec-aug 0 \ + --max-duration $md +fi + +if false; then +epoch=$2 +# avg=$3 +for epoch in $(seq $epoch 5 $((epoch + 24))); do +for avg in $(seq 2 5 $((epoch - 1))); do + python spear_roberta/evaluate_retrieval.py \ + --epoch $epoch \ + --avg $avg \ + --manifest-dir data/manifests \ + --use-averaged-model 1 \ + --downsampling-factor 1,2,4,8,4,2,1 \ + --num-encoder-layers 1,2,3,4,1,1,1 \ + --feedforward-dim 3840,3840,3840,3840,3840,3840,3840 \ + --encoder-dim 1280,1280,1280,1280,1280,1280,1280 \ + --encoder-unmasked-dim 768,768,768,768,768,768,768 \ + --cnn-module-kernel 31,31,15,15,15,31,31 \ + --num-heads 8,8,8,8,8,8,8 \ + --output-downsampling-factor $output_ds \ + --post-encoder-downsampling-factor $post_output_ds \ + --on-the-fly-feats 1 \ + --exp-dir $exp_dir \ + --max-duration $md + done +done +fi + +for i in {0..7}; do CUDA_VISIBLE_DEVICES=$i python /root/busygpu/run.py & done +# python /root/busygpu/run.py & diff --git a/egs/emilia/CLAP/spear_roberta/model.py b/egs/emilia/CLAP/spear_roberta/model.py new file mode 100644 index 0000000000..38765e9e87 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/model.py @@ -0,0 +1,195 @@ +# Copyright 2025 Yifan Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import RobertaModel + +from icefall.utils import make_pad_mask + + +class MLPLayers(nn.Module): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + + +class CLAP(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: nn.Module, + encoder_downsample: Optional[nn.Module] = None, + encoder_dim: int = 384, + text_encoder_dim: int = 768, + joint_dim: int = 512, + ): + """A CLAP model. + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + """ + super().__init__() + + # audio branch + self.encoder_embed = encoder_embed + self.encoder = encoder + self.encoder_downsample = encoder_downsample + self.audio_projection = nn.Sequential( + nn.Linear(encoder_dim, joint_dim), + nn.ReLU(), + nn.Linear(joint_dim, joint_dim), + ) + self.audio_transform = MLPLayers( + units=[joint_dim, joint_dim, joint_dim], dropout=0.1 + ) + + # text branch + self.text_encoder = RobertaModel.from_pretrained("roberta-base") + self.text_projection = nn.Sequential( + nn.Linear(text_encoder_dim, joint_dim), + nn.ReLU(), + nn.Linear(joint_dim, joint_dim), + ) + self.text_transform = MLPLayers( + units=[joint_dim, joint_dim, joint_dim], dropout=0.1 + ) + + self.logit_scale = nn.Parameter(torch.full((), math.log(1 / 0.07))) + + def forward_audio_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute audio encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + with torch.set_grad_enabled(not freeze_encoder): + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_out, encoder_out_lens = self.encoder( + x, x_lens, src_key_padding_mask + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + if self.encoder_downsample is not None: + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_downsample(encoder_out) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out_lens = (encoder_out_lens + 1) // 2 + + padding_mask = make_pad_mask(encoder_out_lens) + encoder_out = encoder_out.masked_fill(padding_mask.unsqueeze(-1), 0.0) + embedding = encoder_out.sum(dim=1) / encoder_out_lens.unsqueeze(-1) # (N, C) + + return embedding + + def forward_text_encoder(self, y: dict, freeze_encoder: bool = False): + with torch.set_grad_enabled(not freeze_encoder): + encoder_out = self.text_encoder( + input_ids=y["input_ids"], + attention_mask=y["attention_mask"], + )["pooler_output"] + + return encoder_out + + def forward( + self, + audio: Optional[torch.Tensor] = None, + audio_lens: Optional[torch.Tensor] = None, + text: Optional[dict] = None, + freeze_audio_encoder: bool = False, + freeze_text_encoder: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + audio: + A 3-D tensor of shape (N, T, C). + audio_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + text: + A dict containing the text input ids and attention mask. + Returns: + Return the CLAP loss + """ + if audio is not None: + assert audio.ndim == 3, audio.shape + assert audio_lens.ndim == 1, audio_lens.shape + + audio_encoder_out = self.forward_audio_encoder( + audio, audio_lens, freeze_encoder=freeze_audio_encoder + ) + audio_encoder_out = self.audio_projection(audio_encoder_out) + audio_encoder_out = self.audio_transform(audio_encoder_out) + audio_encoder_out = F.normalize(audio_encoder_out, dim=-1) + + if text is not None: + assert text["input_ids"].ndim == 2, text["input_ids"].shape + + text_encoder_out = self.forward_text_encoder( + text, freeze_encoder=freeze_text_encoder + ) + text_encoder_out = self.text_projection(text_encoder_out) + text_encoder_out = self.text_transform(text_encoder_out) + text_encoder_out = F.normalize(text_encoder_out, dim=-1) + + return ( + audio_encoder_out if audio is not None else None, + text_encoder_out if text is not None else None, + self.logit_scale.exp(), + ) diff --git a/egs/emilia/CLAP/spear_roberta/optim.py b/egs/emilia/CLAP/spear_roberta/optim.py new file mode 120000 index 0000000000..f2c7f05899 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/optim.py @@ -0,0 +1 @@ +../spear/optim.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear_roberta/scaling.py b/egs/emilia/CLAP/spear_roberta/scaling.py new file mode 120000 index 0000000000..6f398f431d --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear_roberta/subsampling.py b/egs/emilia/CLAP/spear_roberta/subsampling.py new file mode 120000 index 0000000000..01ae9002c6 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/emilia/CLAP/spear_roberta/test_fixedbucketbatchsizeconstraint.py b/egs/emilia/CLAP/spear_roberta/test_fixedbucketbatchsizeconstraint.py new file mode 100755 index 0000000000..079a81b947 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/test_fixedbucketbatchsizeconstraint.py @@ -0,0 +1,129 @@ +import torch +from asr_datamodule import _SeedWorkers +from lhotse import Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + DynamicBucketingSampler, + K2SpeechRecognitionDataset, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from torch.utils.data import DataLoader + + +def _test(rank, world_size, args, logs): + cuts = load_manifest_lazy(args["manifest_path"]).filter( + lambda c: 1 <= c.duration <= 20.0 + ) + + constraint = FixedBucketBatchSizeConstraint( + max_seq_len_buckets=args["max_seq_len_buckets"], + batch_sizes=args["fixed_batch_sizes"], + ) + + sampler = DynamicBucketingSampler( + cuts, + constraint=constraint, + shuffle=True, + drop_last=True, + duration_bins=args["duration_bins"], + buffer_size=args["buffer_size"], + world_size=world_size, + rank=rank, + sync_buckets=True, + concurrent=False, + ) + + dataset = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=128))), + return_cuts=True, + ) + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + dl = DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=8, + persistent_workers=True, + pin_memory=True, + prefetch_factor=16, + worker_init_fn=worker_init_fn, + ) + + for i, batch in enumerate(dl): + cuts_in_batch = batch["supervisions"]["cut"] + bs = len(cuts_in_batch) + c0 = cuts_in_batch[0] + bucket = constraint.select_bucket(constraint.max_seq_len_buckets, example=c0) + shape = batch["inputs"].shape + print( + f"[rank {rank}/{world_size}] Step {i}, batch size={bs}, bucket={bucket}, shape={shape}", + flush=True, + ) + + logs[rank].append((i, bs, bucket)) + + +if __name__ == "__main__": + from multiprocessing import Manager + + import torch.multiprocessing as mp + + max_duration = 1000 + world_size = 8 + seed = 42 + num_buckets = 30 + buffer_size = num_buckets * 5000 + manifest_path = "data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz" + + cuts = load_manifest_lazy(manifest_path).filter(lambda c: 1 <= c.duration <= 20.0) + + dummy_sampler = DynamicBucketingSampler( + cuts, + max_duration=max_duration, + num_buckets=num_buckets, + shuffle=True, + drop_last=True, + buffer_size=buffer_size, + world_size=world_size, + rank=0, + seed=seed, + sync_buckets=True, + concurrent=False, + ) + duration_bins = dummy_sampler.duration_bins + del dummy_sampler + + last_upper = 20.0 # + 1e-6 + max_seq_len_buckets = duration_bins + [last_upper] + fixed_batch_sizes = [max(1, int(max_duration // ub)) for ub in max_seq_len_buckets] + + args = dict( + manifest_path=manifest_path, + duration_bins=duration_bins, + max_seq_len_buckets=max_seq_len_buckets, + fixed_batch_sizes=fixed_batch_sizes, + buffer_size=buffer_size, + seed=seed, + ) + + manager = Manager() + logs = manager.dict({r: manager.list() for r in range(world_size)}) + + mp.spawn(_test, args=(world_size, args, logs), nprocs=world_size, join=True) + + steps_list = [len(logs[r]) for r in range(world_size)] + assert len(set(steps_list)) == 1, f"total steps mismatch across ranks: {steps_list}" + total_steps = steps_list[0] + + for s in range(total_steps): + batch_sizes = [logs[r][s][1] for r in range(world_size)] + buckets = [logs[r][s][2] for r in range(world_size)] + print(f"step {s}: batch size {batch_sizes}, bucket {buckets}") + assert ( + len(set(batch_sizes)) == 1 + ), f"step {s}: batch size mismatch: {batch_sizes}" + assert len(set(buckets)) == 1, f"step {s}: bucket mismatch: {buckets}" + + print(f"Done: verified {total_steps} steps") diff --git a/egs/emilia/CLAP/spear_roberta/zipformer2.py b/egs/emilia/CLAP/spear_roberta/zipformer2.py new file mode 120000 index 0000000000..2c9c437fa8 --- /dev/null +++ b/egs/emilia/CLAP/spear_roberta/zipformer2.py @@ -0,0 +1 @@ +../spear/zipformer2.py \ No newline at end of file diff --git a/icefall/utils.py b/icefall/utils.py index 0d4e24db53..a373d94c56 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -630,8 +630,8 @@ def store_transcripts_and_timestamps( def store_translations( - filename: Pathlike, texts: Iterable[Tuple[str, str, str]], - lowercase: bool = True) -> None: + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], lowercase: bool = True +) -> None: """Save predicted results and reference transcripts to a file. Args: @@ -648,11 +648,13 @@ def store_translations( hyp_list = [] ref_list = [] dir_ = os.path.dirname(filename) - reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) - refsrc = os.path.join(dir_, "refsrc-"+str(os.path.basename(filename))) - hyp = os.path.join(dir_, "hyp-"+str( os.path.basename(filename))) - bleu_file = os.path.join(dir_, "bleu-"+str( os.path.basename(filename))) - with open(filename, "w") as f, open(reftgt, "w") as f_tgt, open(hyp, "w") as f_hyp, open(refsrc, "w") as f_src: + reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) + refsrc = os.path.join(dir_, "refsrc-" + str(os.path.basename(filename))) + hyp = os.path.join(dir_, "hyp-" + str(os.path.basename(filename))) + bleu_file = os.path.join(dir_, "bleu-" + str(os.path.basename(filename))) + with open(filename, "w") as f, open(reftgt, "w") as f_tgt, open( + hyp, "w" + ) as f_hyp, open(refsrc, "w") as f_src: for cut_id, ref, ref_tgt, hyp in texts: ref = " ".join(ref) ref_tgt = " ".join(ref_tgt) @@ -661,7 +663,6 @@ def store_translations( print(f"{cut_id}: ref_tgt {ref_tgt}", file=f) print(f"{cut_id}: hyp {hyp}", file=f) print("\n", file=f) - print(f"{ref}", file=f_src) print(f"{ref_tgt}", file=f_tgt) @@ -670,14 +671,14 @@ def store_translations( hyp_list.append(hyp) ref_list.append(ref_tgt) - with open(bleu_file, 'w') as b: + with open(bleu_file, "w") as b: print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b) print(f"BLEU signiture: {str(bleu.get_signature())}", file=b) - + logging.info( - f"[{bleu.corpus_score(hyp_list, [ref_list])}] " - f"BLEU signiture: {str(bleu.get_signature())}" - ) + f"[{bleu.corpus_score(hyp_list, [ref_list])}] " + f"BLEU signiture: {str(bleu.get_signature())}" + ) def write_error_stats( @@ -1289,14 +1290,14 @@ def __str__(self) -> str: ans_utterances += str(k) + "=" + str(norm_value) if k == "utt_duration": ans_utterances += " frames, " - elif k == "utt_pad_proportion": - ans_utterances += ", " else: - raise ValueError(f"Unexpected key: {k}") - frames = "%.2f" % self["frames"] - ans_frames += "over " + str(frames) + " frames. " + ans_utterances += ", " + + if "frames" in self: + frames = "%.2f" % self.get("frames", 0) + ans_frames += "over " + str(frames) + " frames. " if ans_utterances != "": - utterances = "%.2f" % self["utterances"] + utterances = "%.2f" % self.get("utterances", 0) ans_utterances += "over " + str(utterances) + " utterances." return ans_frames + ans_utterances @@ -1306,8 +1307,8 @@ def norm_items(self) -> List[Tuple[str, float]]: Returns a list of pairs, like: [('ctc_loss', 0.1), ('att_loss', 0.07)] """ - num_frames = self["frames"] if "frames" in self else 1 - num_utterances = self["utterances"] if "utterances" in self else 1 + num_frames = self.get("frames", 1) + num_utterances = self.get("utterances", 1) ans = [] for k, v in self.items(): if k == "frames" or k == "utterances": @@ -2454,3 +2455,12 @@ def time_warp( ) return features + + +def compare_model(state_dict1, state_dict2): + assert state_dict1.keys() == state_dict2.keys() + for key in state_dict1.keys(): + if torch.all(state_dict1[key] == state_dict2[key]): + logging.info(f"Param: {key} is the same as new state dict") + else: + logging.info(f"Param: {key} is updated from new state dict") diff --git a/requirements.txt b/requirements.txt index 885bf2fc3d..5ca60962d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,11 +8,6 @@ pypinyin==0.50.0 tensorboard typeguard dill -onnx>=1.15.0 -onnxruntime>=1.16.3 -onnxoptimizer -onnxsim -onnxconverter_common # style check session: black==22.3.0