Skip to content

Suppress tokens during training #1

@pprobst

Description

@pprobst

Hello. First of all, thank you for your work.

I came upon this repo when trying to improve the transcription speed in whispercpp by using a lower audio_ctx.
However, when I fine-tune with this code, tokens do not seem to be suppressed. I adapted the code slightly to set suppress_tokens and to use my Spanish dataset, but everything else of importance is unchanged.

When fine-tuning normally, at least with WhisperForConditionalGeneration, setting model.config.suppress_tokens = suppress works, but I'm not sure it is working here. Furthermore, my training dataset (about 6100 audio files) does not use any punctuation marks; that is, "," is written as the word "comma" rather than the actual character. So I would at least expect the model to learn from the data itself not to use punctuation, even if I do not suppress it explicitly during inference. That is not what happens here, even when I train for 15 epochs, the same number of epochs I use in my "normal" pipeline that does not use dynamic audio context.

Also, in the code below, suppressing tokens only works during inference if I set these in the model.generate function:

# Swap the fine-tuned WhisperModel into the generation wrapper, then decode with
# the suppression list and the language/task prompt passed explicitly to
# generate() rather than relying on values stored in model.config.
model.model = model_train.eval().cuda()
predicted_ids_train = model.generate(
    input_features,
    suppress_tokens=suppress,
    forced_decoder_ids=processor.get_decoder_prompt_ids(
        language=language, task="transcribe"
    ),
)

This made me suspect that it's not suppressing tokens during training. But I don't know how to verify this.

Full code (except loading the dataset, but nothing unusual there):

#!/usr/bin/env python3

import shutil
import torch

from datasets import load_dataset
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch import nn
from transformers import (
    WhisperModel,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
from dataset import CSVDataset
from sys import argv

# Number of passes over the training split.
EPOCHS = 3
# Seed handed to ds.shuffle() in the training loop.
SEED = 42
# Characters to suppress: get_suppress_tokens() selects a token when every
# character of its decoded text appears in this string.
SUPPRESS_TOKENS_TRAIN = "0123456789@#%&*+=_$€:-.,?¿;!¡"

# Accepted spellings of the language argument, mapped to the language code
# that is passed to the processor and to get_decoder_prompt_ids().
languages = {
    "spanish": "es",
    "portuguese": "pt",
    "es": "es",
    "pt": "pt",
}

# Command line: <model size> <language> <dataset path>.
# MODEL is interpolated into the checkpoint name "openai/whisper-{MODEL}".
MODEL = argv[1]
# Raises KeyError for a language that is not listed in `languages`.
language = languages[argv[2].lower()]
data_path = Path(argv[3])


def get_suppress_tokens(tokenizer: "WhisperTokenizer", tokens: str):
    """Return ids of vocabulary tokens made up only of characters in ``tokens``.

    Every id below ``tokenizer.eos_token_id`` is decoded on its own and one
    leading space is stripped. The id is kept when the remaining text is
    non-empty and consists solely of characters found in ``tokens``.

    Args:
        tokenizer: Object exposing ``eos_token_id`` and ``decode(list_of_ids)``.
            The annotation is a string so the function does not need the
            class to be importable when it is defined.
        tokens: The characters to match, e.g. digits and punctuation marks.

    Returns:
        The matching token ids in ascending order.
    """
    charset = set(tokens)
    suppressed_tokens = []
    for token_id in range(tokenizer.eos_token_id):
        text = tokenizer.decode([token_id]).removeprefix(" ")
        # all() over an empty string is vacuously True, which would put tokens
        # that decode to "" or to a single space on the suppress list even
        # though they contain none of the requested characters.
        if text and all(ch in charset for ch in text):
            suppressed_tokens.append(token_id)
    return suppressed_tokens


# Feature extractor + tokenizer for the chosen checkpoint, configured for
# transcription in the target language.
processor = WhisperProcessor.from_pretrained(
    f"openai/whisper-{MODEL}", task="transcribe", language=language
)


# Two copies of the same checkpoint: model_train is the one being optimised,
# model_base is only run under torch.no_grad() and supplies the reference
# hidden states for the loss.
model_train = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}")
model_base = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}")

# NOTE: the values below are only stored on the model configs. The training
# step in this file calls model.encoder / model.decoder directly and compares
# hidden states with MSELoss, so nothing in the visible training code reads
# suppress_tokens or forced_decoder_ids; they matter only to code that
# consults the config (e.g. at generation time).
# NOTE(review): the leading -1 is kept from the original script -- confirm how
# the installed transformers version interprets a negative id in this list.
suppress = [-1] + get_suppress_tokens(processor.tokenizer, SUPPRESS_TOKENS_TRAIN)
model_train.config.suppress_tokens = suppress
model_base.config.suppress_tokens = suppress
model_train.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model_base.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)

# Training mode for the student, eval mode for the reference; both on GPU.
model_train = model_train.train().cuda()
model_base = model_base.eval().cuda()


# Project-local dataset wrapper built from the path given on the command line.
# NOTE(review): the meaning of add_silence / ignore_datasets / audio_augs is
# defined in dataset.py, which is not visible here.
ds_full = CSVDataset(
    data_path=data_path,
    processor=processor,
    add_silence=False,
    ignore_datasets=[],
    audio_augs=True,
)

# Training split; shuffled and iterated in the training loop below.
ds = ds_full.dataset["train"]

# ds = load_dataset("google/fleurs", "en_us", split="train")


def get_sample(example):
    """Convert one dataset row into the inputs consumed by the training step.

    Reads ``example["audio"]`` (``array`` and ``sampling_rate``) and
    ``example["transcript"]`` and uses the module-level ``processor``.

    Returns:
        A dict with
        ``length``         -- clip duration in seconds (samples / rate),
        ``input_features`` -- the processor's ``input_features`` tensor,
        ``input_ids``      -- tokenizer encoding of the lower-cased transcript.
    """
    audio = example["audio"]
    samples = audio["array"]
    rate = audio["sampling_rate"]

    # Feature extraction for the raw waveform, returned as PyTorch tensors.
    features = processor(samples, sampling_rate=rate, return_tensors="pt")

    sample = {"length": len(samples) / rate}
    sample["input_features"] = features.input_features
    sample["input_ids"] = processor.tokenizer.encode(example["transcript"].lower())
    return sample


# if not (".en" in MODEL):
#    print(processor.get_decoder_prompt_ids(language="english", task="transcribe"))
# [processor.tokenizer.decode(i) for i in get_sample(ds[1])["input_ids"]]


def compute_partially_encoder(model, data, n_audio_ctx):
    """Run ``model.encoder`` on ``data`` using only ``n_audio_ctx`` positions.

    ``data`` is first zero-padded or truncated along dimension 2 to exactly
    ``2 * n_audio_ctx`` frames. When ``n_audio_ctx`` equals 1500 the encoder's
    own forward pass is used. Otherwise its sub-modules are applied by hand so
    that only the first ``n_audio_ctx`` rows of the positional embedding table
    are added to the convolution output.

    Args:
        model: Object with an ``encoder`` attribute exposing ``conv1``,
            ``conv2``, ``embed_positions``, ``dropout``, ``layerdrop``,
            ``layers``, ``layer_norm``, ``training`` and
            ``gradient_checkpointing``.
        data: Input feature tensor; dimension 2 is the one being resized.
        n_audio_ctx: Number of encoder positions to compute.

    Returns:
        The hidden states after the encoder's final layer norm.
    """
    surplus = 2 * n_audio_ctx - data.shape[2]
    if surplus > 0:
        # Too short: append zeros on the right of the last dimension.
        data = nn.functional.pad(data, [0, surplus, 0, 0, 0, 0], "constant", 0.0)
    elif surplus < 0:
        # Too long: drop the trailing frames.
        data = data[:, :, :surplus]

    encoder = model.encoder
    if n_audio_ctx == 1500:
        return encoder(data).last_hidden_state

    # Convolutional front end, then move the frame axis to dimension 1.
    features = nn.functional.gelu(encoder.conv1(data))
    features = nn.functional.gelu(encoder.conv2(features)).permute(0, 2, 1)

    # Only the leading rows of the positional table are used.
    positions = encoder.embed_positions.weight[:n_audio_ctx]
    hidden_states = nn.functional.dropout(
        features + positions, p=encoder.dropout, training=encoder.training
    )

    for encoder_layer in encoder.layers:
        # In training mode a layer is skipped when a fresh random draw falls
        # below encoder.layerdrop; no draw is made in eval mode.
        if encoder.training and torch.rand([]) < encoder.layerdrop:
            continue

        if encoder.gradient_checkpointing and encoder.training:
            outputs = encoder._gradient_checkpointing_func(
                encoder_layer.__call__, hidden_states, None, None, False
            )
        else:
            outputs = encoder_layer(
                hidden_states, None, layer_head_mask=None, output_attentions=False
            )
        hidden_states = outputs[0]

    return encoder.layer_norm(hidden_states)


def compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example):
    """Run one optimisation step that matches decoder hidden states.

    ``model_train`` encodes the clip with a shortened, randomly jittered audio
    context; ``model_base`` encodes it with the full 1500-position context
    under ``torch.no_grad()``. Both decoders are fed the same ``input_ids`` and
    ``criterion`` compares all of their hidden states. The function calls
    ``zero_grad``, ``backward`` and ``optimizer.step`` itself.

    Note that only hidden states enter the loss: no logits are computed and no
    config-level generation setting (such as ``suppress_tokens``) is read here.

    Args:
        model_train: Model being optimised (``encoder`` / ``decoder`` are used).
        model_base: Reference model, evaluated without gradients.
        optimizer: Optimiser over ``model_train``'s parameters.
        criterion: Loss callable taking ``(prediction, target)``.
        example: Dict with ``length`` (seconds), ``input_features`` (tensor)
            and ``input_ids`` (list of token ids), as built by ``get_sample``.

    Returns:
        The loss tensor for this step.
    """
    optimizer.zero_grad()

    # 1500 positions correspond to 30 s, i.e. 50 positions per second.
    n_ctx = int(round((1500.0 / 30.0) * example["length"]))

    # Random jitter of the context length. torch.randint requires low < high,
    # so very short clips (n_ctx < 3 gives a bound of 0) get no jitter instead
    # of raising.
    jitter_bound = min(64, n_ctx // 3)
    if jitter_bound > 0:
        n_ctx += torch.randint(-jitter_bound, jitter_bound, (1,)).item()

    # Keep the context within [1, 1500]: above 1500 (e.g. a 29 s clip plus
    # positive jitter) the positional-embedding slice taken in
    # compute_partially_encoder would be shorter than the convolution output
    # (assuming the table has 1500 rows), and 0 would leave nothing to encode.
    n_ctx = max(1, min(1500, n_ctx))

    input_features = example["input_features"].cuda()
    input_ids = torch.tensor([example["input_ids"]], dtype=torch.long).cuda()

    # Student: shortened audio context, gradients enabled.
    encoder_hidden_states_partial = compute_partially_encoder(
        model_train, input_features, n_ctx
    )
    output_partial = model_train.decoder(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states_partial,
        output_hidden_states=True,
    )

    # Reference: full audio context, no gradients.
    with torch.no_grad():
        encoder_hidden_states_full = compute_partially_encoder(
            model_base, input_features, 1500
        )
        output_full = model_base.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states_full,
            output_hidden_states=True,
        )

    # Compare every returned decoder hidden state, stacked along dimension 0.
    loss = criterion(
        torch.cat(output_partial.hidden_states, 0),
        torch.cat(output_full.hidden_states, 0),
    )

    loss.backward()
    optimizer.step()

    return loss


# Mean-squared error between student and reference hidden states; only
# model_train's parameters are handed to the optimiser.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_train.parameters(), lr=1e-6)


# TensorBoard logging (default log directory).
writer = SummaryWriter()
writer.add_text("name", f"{MODEL} v3")

num_length = 0  # total seconds of audio trained on so far
step = 0  # number of optimisation steps taken
for epoch in range(EPOCHS):
    # Offset the seed by the epoch so each epoch sees a different, yet still
    # reproducible, order; a constant seed replays the same order every epoch.
    pbar = tqdm(ds.shuffle(seed=SEED + epoch))
    for example in pbar:
        example = get_sample(example)
        # Skip clips longer than 29 s.
        if example["length"] > 29.0:
            continue

        loss = compute_hidden_state_loss(
            model_train, model_base, optimizer, criterion, example
        )
        # Read the scalar once and reuse it for logging and the progress bar.
        loss_value = loss.item()
        step += 1
        num_length += example["length"]

        writer.add_scalar("loss/train", loss_value, step)
        writer.add_scalar("length/train", num_length, step)
        writer.add_scalar("epoch/train", epoch, step)

        pbar.set_description(f"Epoch {epoch}, Loss: {loss_value}")

# Make sure buffered events reach disk before the evaluation phase starts.
writer.flush()


# Select an audio file and read it:
# ds_eval = load_dataset(
#    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
# )
# Validation split of the project dataset.
ds_eval = ds_full.dataset["val"]

# Load the Whisper model in Hugging Face format:
# (used as a generation wrapper; its inner `.model` is swapped per run below)
model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}")

model.config.suppress_tokens = suppress
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model = model.eval().cuda()

# Transcribe every validation clip with both the reference and the fine-tuned
# model and print the two results next to the ground truth.
for i in range(len(ds_eval)):
    audio_sample = ds_eval[i]["audio"]
    waveform = audio_sample["array"]
    sampling_rate = audio_sample["sampling_rate"]
    # print(ds_eval[i])
    # input_features = ds_eval[i]["input_features"].cuda()

    # Use the model and processor to transcribe the audio:
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features.cuda()

    # Reference model first.
    model.model = model_base.eval().cuda()
    # Suppress tokens only work here, not when setting in the model.config. Why?
    # NOTE(review): generate() presumably takes its defaults from
    # model.generation_config (created when the checkpoint is loaded) rather
    # than from later edits to model.config -- confirm against the installed
    # transformers version; setting model.generation_config.suppress_tokens
    # may be the config-level equivalent.
    predicted_ids_base = model.generate(
        input_features,
        suppress_tokens=suppress,
        forced_decoder_ids=processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        ),
    )
    # Then the fine-tuned model, with identical generation arguments.
    model.model = model_train.eval().cuda()
    predicted_ids_train = model.generate(
        input_features,
        suppress_tokens=suppress,
        forced_decoder_ids=processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        ),
    )

    # Decode token ids to text
    transcription = processor.batch_decode(
        [predicted_ids_base[0], predicted_ids_train[0]], skip_special_tokens=True
    )

    # Use self.tokenizer._basic_normalize(pred).strip() to normalize the transcriptions
    transcription = [
        processor.tokenizer._basic_normalize(pred).strip() for pred in transcription
    ]

    # GrndTr = ground truth, ModelB = reference model, ModelT = fine-tuned model.
    print(
        f"\n\nGrndTr: {ds_eval[i]['transcript'].lower()}\nModelB: {transcription[0]}\nModelT: {transcription[1]}"
    )

# Build a fresh generation wrapper on CPU, attach the fine-tuned WhisperModel
# as its inner `.model`, and export it together with the tokenizer.
model = (
    WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}")
    .eval()
    .cpu()
)
model.config.suppress_tokens = suppress
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model.model = model_train.eval().cpu()

# Model and tokenizer files go into the same output directory.
model.save_pretrained(f"model_train-{MODEL}3")
processor.tokenizer.save_pretrained(f"model_train-{MODEL}3")

# Zip that directory into "model_train-{MODEL}3.zip".
shutil.make_archive(f"model_train-{MODEL}3", "zip", f"model_train-{MODEL}3")

Finally, when running inference on a test set using whispercpp and the fine-tuned model, I get a WER of 39% instead of the usual 7% that I get when trained "normally".

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions