[export] aot.export fails when model contains bfloat16 tensors (unsupported ScalarType BFloat16) #1127

@eomiso

Description

When exporting a Hugging Face Llama 3.1 model using the iree-turbine AOT API, the exporter crashes with:

TypeError: Got unsupported ScalarType BFloat16

The error arises when the exporter tries to convert a PyTorch bfloat16 tensor into a NumPy array inside iree/turbine/aot/support/ir_utils.py:create_tensor_global:

array = np.array(detached_tensor)  # torch._tensor.py → .numpy() fails

NumPy has no native bfloat16 dtype, so the conversion fails.
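A minimal sketch of the failure outside of Turbine, plus one possible way around it: reinterpreting the bfloat16 storage as 16-bit integers so the raw bits can pass through NumPy unchanged. The helper names `bf16_to_numpy_bits` and `numpy_bits_to_bf16` are hypothetical and are not part of iree-turbine; this only illustrates the idea and is not how `create_tensor_global` is or should be implemented.

```python
import numpy as np
import torch


def bf16_to_numpy_bits(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: expose bfloat16 storage as a NumPy int16 array."""
    if t.dtype != torch.bfloat16:
        raise ValueError(f"expected bfloat16, got {t.dtype}")
    # view() reinterprets the same 16-bit storage without changing any bits.
    return t.detach().cpu().contiguous().view(torch.int16).numpy()


def numpy_bits_to_bf16(bits: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: reinterpret int16 bits as a bfloat16 tensor."""
    return torch.from_numpy(bits.copy()).view(torch.bfloat16)


t = torch.tensor([1.0, -2.5, 3.140625], dtype=torch.bfloat16)

# Direct conversion is what fails in the exporter.
try:
    t.numpy()
except TypeError as exc:
    print(f"direct conversion failed: {exc}")

# Round-tripping the raw bits preserves the values exactly.
bits = bf16_to_numpy_bits(t)
restored = numpy_bits_to_bf16(bits)
print(bits.dtype, torch.equal(restored, t))
```

Whatever consumes the array would still need to be told that the payload is bfloat16 rather than int16, so this only addresses the NumPy step, not the rest of the export.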

Steps to reproduce

import torch
import iree.turbine.aot as aot
from transformers import AutoModelForCausalLM

MODEL = "/Users/.../Llama3.1-8B-Instruct-hf"

class OneStep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype="auto")
        self.m.config.sliding_window = None
        self.m.eval()

    def forward(self, input_ids, attention_mask):
        out = self.m(input_ids=input_ids, attention_mask=attention_mask.bool(), use_cache=False)
        return out.logits[:, -1, :]

mod = OneStep()
ex_ids = torch.empty(1, 64, dtype=torch.int64)   # token ids must be int64 for the embedding lookup
ex_mask = torch.empty(1, 64, dtype=torch.int64)  # cast to bool inside forward()

export = aot.export(mod, ex_ids, ex_mask)  # crash here

Environment

  • macOS (Apple Silicon)
  • Python 3.12
  • torch 2.7.0
  • iree-base-compiler 3.6.0
  • iree-base-runtime 3.6.0
  • iree-turbine 3.6.0

Workarounds

For now, forcing the model to load with torch_dtype="float32" or "float16" avoids the crash:

self.m = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype="float16")

However, it’s unclear whether compiling with float16 introduces correctness issues for these models. Some Llama checkpoints ship in bfloat16, which keeps the float32 exponent range that float16 lacks, so I wonder whether a silent cast to float16 may overflow or otherwise degrade results. It would be better if Turbine either:

  • supported bfloat16 directly end-to-end, or
  • documented the implications of automatic downcasting.
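To make the downcasting concern concrete, here is a small standard-library sketch comparing the two 16-bit formats. It assumes round-to-nearest-even when narrowing float32 to bfloat16 and only handles finite inputs; the function names are made up for illustration and have nothing to do with Turbine or PyTorch internals.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 is the top 16 bits of float32; round before dropping the rest.
    rounding = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_float16(x: float) -> float:
    """Round a float to the nearest IEEE float16 value."""
    # struct's "e" format raises OverflowError when x is out of float16 range.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_float16(x: float) -> bool:
    """Return True if x is representable within float16's finite range."""
    try:
        to_float16(x)
        return True
    except OverflowError:
        return False


# Range: bfloat16 can hold a value that float16 cannot.
print(to_bfloat16(70000.0))   # 70144.0
print(fits_float16(70000.0))  # False

# Precision: float16 resolves small differences that bfloat16 rounds away.
print(to_float16(1.001))      # 1.0009765625
print(to_bfloat16(1.001))     # 1.0
```

So the trade-off runs in both directions: float16 is finer-grained near 1.0 but tops out at 65504, while bfloat16 is coarser but covers roughly the float32 range. Whether any weights or activations in a given checkpoint actually exceed the float16 range is something that would need to be measured.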
