Description
When exporting a Hugging Face Llama 3.1 model through the iree-turbine AOT API, the exporter crashes with:

```
TypeError: Got unsupported ScalarType BFloat16
```
The error arises when the exporter tries to convert a PyTorch bfloat16 tensor into a NumPy array inside `iree/turbine/aot/support/ir_utils.py` (`create_tensor_global`):

```python
array = np.array(detached_tensor)  # torch._tensor.py → .numpy() fails
```

NumPy has no native bfloat16 dtype, so the conversion fails.
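The root cause can be reproduced without torch at all. The sketch below (assumptions: plain NumPy, and a hypothetical helper `bf16_bits_to_float32` that is not part of iree-turbine) shows both that NumPy rejects the dtype and one lossless way around it: since bfloat16 is just float32 with the low 16 mantissa bits truncated, the raw bf16 bit patterns can be widened into float32 without rounding.

```python
import numpy as np

# NumPy has no native bfloat16 dtype, which is why the conversion crashes:
try:
    np.dtype("bfloat16")
except TypeError as e:
    print(e)  # e.g. "data type 'bfloat16' not understood"

def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Hypothetical workaround helper: widen raw bfloat16 bit patterns
    (uint16) into float32 by shifting them into the top 16 bits.
    This is lossless because bf16 is a truncated float32."""
    return (bits.astype(np.uint32) << 16).view(np.float32)

# 0x3F80 and 0x4000 are the bf16 bit patterns for 1.0 and 2.0
vals = bf16_bits_to_float32(np.array([0x3F80, 0x4000], dtype=np.uint16))
print(vals)  # [1. 2.]
```

Libraries such as `ml_dtypes` register a bfloat16 extension dtype with NumPy using essentially this representation, which is one possible path to end-to-end support.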
Steps to reproduce
```python
import torch
import iree.turbine.aot as aot
from transformers import AutoModelForCausalLM

MODEL = "/Users/.../Llama3.1-8B-Instruct-hf"

class OneStep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype="auto")
        self.m.config.sliding_window = None
        self.m.eval()

    def forward(self, input_ids, attention_mask):
        out = self.m(input_ids=input_ids, attention_mask=attention_mask.bool(), use_cache=False)
        return out.logits[:, -1, :]

mod = OneStep()
ex_ids = torch.empty(1, 64, dtype=torch.int16)
ex_mask = torch.empty(1, 64, dtype=torch.int16)
export = aot.export(mod, ex_ids, ex_mask)  # crash here
```
Environment
- macOS (Apple Silicon)
- Python 3.12
- torch 2.7.0
- iree-base-compiler 3.6.0
- iree-base-runtime 3.6.0
- iree-turbine 3.6.0
Workarounds
For now, forcing the model to load with `torch_dtype="float32"` or `"float16"` avoids the crash:

```python
self.m = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype="float16")
```

However, it's unclear whether compiling with float16 introduces correctness issues for these models. Some Llama checkpoints ship in bfloat16 specifically to preserve accuracy, so I wonder if a silent cast to float16 may degrade results. It would be better if Turbine either:
- supported bfloat16 directly end-to-end, or
- documented the implications of automatic downcasting.
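The accuracy concern is concrete, not hypothetical: float16 has only 5 exponent bits (max finite value 65504), while bfloat16 keeps float32's 8 exponent bits (max finite value ≈ 3.4e38). Any weight or activation above 65504 overflows to infinity under a silent bfloat16 → float16 cast. A minimal NumPy sketch:

```python
import numpy as np

x = np.float32(1e5)                 # representable in bfloat16 (and float32)
y = np.float16(x)                   # overflows float16's range
print(y)                            # inf
print(np.finfo(np.float16).max)     # 65504.0
```

The `float32` workaround avoids this overflow risk entirely, at the cost of doubled weight memory.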