Self Checks
1. Is this request related to a challenge you're experiencing? Tell us your story.
Fish Speech currently only supports NVIDIA CUDA and CPU backends for inference. AMD GPU users (both consumer RDNA and datacenter Instinct) have no official path to run Fish Speech with GPU acceleration.
With AMD's ROCm stack maturing (7.2 released with PyTorch 2.11.0 support) and the growing availability of high-VRAM AMD cards (RX 7900 XTX 24GB, RX 9070 XT 16GB, MI300X), there's a meaningful user base that could benefit from ROCm support.
Key challenges identified while prototyping ROCm support:
- The Dockerfile only provides CUDA and CPU base stages
- Docker Compose has no GPU passthrough configuration for AMD (`/dev/kfd`, `/dev/dri`)
- Several code paths use hardcoded `.cuda()` calls instead of device-agnostic alternatives
- The VQ-GAN decoder loads in float32 regardless of the requested precision, consuming ~1.87GB instead of ~1.39GB in bfloat16 — problematic for 16GB cards
- KV cache is always allocated at `max_seq_len=32768`, which combined with model weights exceeds 16GB VRAM
- `pyproject.toml` lacks a `rocm72` dependency extra for the ROCm PyTorch index
2. What is your suggested solution?
- Dockerfile: Add a `base-rocm` stage using `rocm/dev-ubuntu-24.04:7.2` as the base image
- Docker Compose: Add a `compose.rocm.yml` overlay that maps `/dev/kfd` and `/dev/dri`, sets the render group, and exposes ROCm-specific env vars (`HSA_ENABLE_SDMA`, `GPU_MAX_HW_QUEUES`, `PYTORCH_HIP_ALLOC_CONF`)
- `pyproject.toml`: Add a `rocm72` extra pointing to PyTorch's ROCm 7.2 index (https://download.pytorch.org/whl/rocm7.2)
- Device-agnostic code: Replace hardcoded `.cuda()` calls with `.to(device)` in `modded_dac.py`, `extract_vq.py`, and `model_utils.py`
- VQ-GAN precision: Pass the user-requested precision (bfloat16/float16) through to `load_model()` in `dac/inference.py` so the decoder benefits from the same memory savings as the Llama model
- Configurable KV cache: Respect a `MAX_SEQ_LEN` environment variable so users with constrained VRAM can reduce KV cache allocation (default remains 32768)
- Optional VRAM cap: Respect a `VRAM_FRACTION` environment variable that calls `torch.cuda.set_per_process_memory_fraction()` to prevent driver-level freezes on cards where VRAM overcommit is fatal (a minimal sketch of the device, precision, and env-var handling follows this list)
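For illustration, here is a minimal sketch of the last three points, assuming a ROCm build of PyTorch (which exposes HIP devices through the `torch.cuda` namespace); `configure_device` is a hypothetical helper, not existing Fish Speech code, and the real change would be wired into the existing loaders:

```python
import os

import torch


def configure_device(precision: str = "bfloat16") -> tuple[torch.device, torch.dtype]:
    """Hypothetical helper illustrating the proposed behaviour."""
    # ROCm builds of PyTorch expose HIP devices via torch.cuda, so "cuda"
    # selects the AMD GPU just as it would an NVIDIA one.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Honour the requested half precision instead of forcing float32.
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16}.get(
        precision, torch.float32
    )

    # Optional hard cap to avoid driver-level freezes when VRAM is overcommitted.
    vram_fraction = os.environ.get("VRAM_FRACTION")
    if vram_fraction and device.type == "cuda":
        torch.cuda.set_per_process_memory_fraction(float(vram_fraction))

    return device, dtype


# Let constrained cards shrink the KV cache; the default stays at 32768.
max_seq_len = int(os.environ.get("MAX_SEQ_LEN", "32768"))

device, dtype = configure_device("bfloat16")
# Model loading would then use .to(device=device, dtype=dtype) instead of .cuda().
```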
Tested successfully on: RX 9070 (16GB, gfx1201/RDNA4), Ubuntu 24.04, kernel 6.17 + amdgpu-dkms 6.18.4, ROCm 7.2, PyTorch 2.11.0.
3. Additional context or comments
Usage:
BACKEND=rocm UV_EXTRA=rocm72 VRAM_FRACTION=0.95 MAX_SEQ_LEN=4096 \
docker compose -f compose.yml -f compose.rocm.yml --profile webui up
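For context, the overlay passed via `-f compose.rocm.yml` could look roughly like this; the service name `fish-speech` and the specific env var values are illustrative assumptions, not the final implementation:

```yaml
# Sketch only: service name and env var values are illustrative assumptions.
services:
  fish-speech:
    devices:
      - /dev/kfd   # ROCm compute interface
      - /dev/dri   # GPU render nodes
    group_add:
      - render     # group that owns the GPU device nodes on most distros
    environment:
      - HSA_ENABLE_SDMA=0                                 # example value
      - GPU_MAX_HW_QUEUES=1                               # example value
      - PYTORCH_HIP_ALLOC_CONF=expandable_segments:True   # example value
```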
A working implementation is available and a PR will follow this issue.
4. Can you help us with this feature?