Support for AMD ROCm Devices #1246

@gtherond

Description

Self Checks

  • I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find any relevant information that meets my needs.
  • I have searched for existing issues, including closed ones.
  • I confirm that I am using English to submit this report.
  • [FOR CHINESE USERS] Please submit issues in English, or they will be closed. Thanks! :)
  • Please do not modify this template :) and fill in all the required fields.

1. Is this request related to a challenge you're experiencing? Tell us your story.

Fish Speech currently supports only NVIDIA CUDA and CPU backends for inference. AMD GPU users (both consumer RDNA and datacenter Instinct cards) have no official path to GPU-accelerated Fish Speech.

With AMD's ROCm stack maturing (7.2 released with PyTorch 2.11.0 support) and the growing availability of high-VRAM AMD cards (RX 7900 XTX 24GB, RX 9070 XT 16GB, MI300X), there's a meaningful user base that could benefit from ROCm support.

Key challenges identified while prototyping ROCm support:

  • The Dockerfile only provides CUDA and CPU base stages
  • Docker Compose has no GPU passthrough configuration for AMD (/dev/kfd, /dev/dri)
  • Several code paths use hardcoded .cuda() calls instead of device-agnostic alternatives
  • The VQ-GAN decoder loads in float32 regardless of the requested precision, consuming ~1.87GB instead of ~1.39GB in bfloat16 — problematic for 16GB cards
  • KV cache is always allocated at max_seq_len=32768, which combined with model weights exceeds 16GB VRAM
  • pyproject.toml lacks a rocm72 dependency extra for the ROCm PyTorch index

2. What is your suggested solution?

  • Dockerfile: Add a base-rocm stage using rocm/dev-ubuntu-24.04:7.2 as the base image
  • Docker Compose: Add a compose.rocm.yml overlay that maps /dev/kfd and /dev/dri, sets render group, and exposes ROCm-specific env vars (HSA_ENABLE_SDMA, GPU_MAX_HW_QUEUES, PYTORCH_HIP_ALLOC_CONF)
  • pyproject.toml: Add a rocm72 extra pointing to PyTorch's ROCm 7.2 index (https://download.pytorch.org/whl/rocm7.2)
  • Device-agnostic code: Replace hardcoded .cuda() calls with .to(device) in modded_dac.py, extract_vq.py, and model_utils.py
  • VQ-GAN precision: Pass the user-requested precision (bfloat16/float16) through to load_model() in dac/inference.py so the decoder benefits from the same memory savings as the Llama model
  • Configurable KV cache: Respect a MAX_SEQ_LEN environment variable to allow users with constrained VRAM to reduce KV cache allocation (default remains 32768)
  • Optional VRAM cap: Respect a VRAM_FRACTION environment variable that calls torch.cuda.set_per_process_memory_fraction() to prevent driver-level freezes on cards where VRAM overcommit is fatal
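The compose overlay suggested above could look roughly like this (a sketch only; the service name `fish-speech` and the env values are illustrative, not taken from the repository):

```yaml
# compose.rocm.yml — hypothetical overlay for AMD GPU passthrough
services:
  fish-speech:
    devices:
      - /dev/kfd        # ROCm compute interface
      - /dev/dri        # GPU render nodes
    group_add:
      - render          # grants the container access to the device nodes
    environment:
      HSA_ENABLE_SDMA: "0"       # workaround for SDMA issues on some cards
      GPU_MAX_HW_QUEUES: "1"
      PYTORCH_HIP_ALLOC_CONF: "expandable_segments:True"
```

Merged via `docker compose -f compose.yml -f compose.rocm.yml`, this adds the AMD device mappings without touching the CUDA/CPU paths in the base file.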

Tested successfully on: RX 9070 (16GB, gfx1201/RDNA4), Ubuntu 24.04, kernel 6.17 + amdgpu-dkms 6.18.4, ROCm 7.2, PyTorch 2.11.0.
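The device-agnostic and environment-variable changes proposed above can be sketched as follows. This is a minimal illustration, not the actual Fish Speech code; `resolve_device`, `apply_vram_cap`, and `kv_cache_seq_len` are hypothetical names. Note that ROCm builds of PyTorch expose HIP through the `torch.cuda` namespace, so the same code path serves both vendors.

```python
import os
import torch

def resolve_device() -> torch.device:
    """Pick the best available device; ROCm/HIP builds report as 'cuda'."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def apply_vram_cap(device: torch.device) -> None:
    """Cap per-process VRAM when VRAM_FRACTION is set, to avoid fatal overcommit."""
    fraction = os.environ.get("VRAM_FRACTION")
    if fraction is not None and device.type == "cuda":
        # Available on both CUDA and ROCm builds of PyTorch.
        torch.cuda.set_per_process_memory_fraction(float(fraction))

def kv_cache_seq_len(default: int = 32768) -> int:
    """Let MAX_SEQ_LEN shrink the KV cache allocation on constrained cards."""
    return int(os.environ.get("MAX_SEQ_LEN", default))

device = resolve_device()
apply_vram_cap(device)
# Device-agnostic placement: .to(device) instead of a hardcoded .cuda()
model = torch.nn.Linear(8, 8).to(device)
```

The same `.to(device)` pattern would replace the hardcoded `.cuda()` calls in `modded_dac.py`, `extract_vq.py`, and `model_utils.py`.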

3. Additional context or comments

Usage:

BACKEND=rocm UV_EXTRA=rocm72 VRAM_FRACTION=0.95 MAX_SEQ_LEN=4096 \
  docker compose -f compose.yml -f compose.rocm.yml --profile webui up
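For reference, the `rocm72` extra could be wired up in `pyproject.toml` roughly like this (a uv-style sketch; the exact table names depend on how the project pins its other backends):

```toml
[project.optional-dependencies]
rocm72 = ["torch==2.11.0"]

[tool.uv.sources]
torch = [{ index = "pytorch-rocm", extra = "rocm72" }]

[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/rocm7.2"
explicit = true
```

With `explicit = true`, the ROCm index is consulted only for the packages mapped to it, so the default CUDA/CPU resolution is unaffected.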

A working implementation is available and a PR will follow this issue.

4. Can you help us with this feature?

  • I am interested in contributing to this feature.
