ROCm support scripts in the future?
#2
by choc-chilla - opened
Will it be a long wait? Poor man's AI user here <3
I got it working on ROCm. Here is the setup:
Hardware tested: RX 7900 XTX, ROCm 7.2.3, PyTorch 2.11+rocm7.2
Key findings:
- Flash Attention is NOT required. SA3 falls back to Triton flex_attention kernels that work on ROCm out of the box.
- Relax the strict torch==2.7.1 and torchaudio==2.7.1 pins in pyproject.toml (the ROCm wheels are 2.11+). Change them to >=2.7.1.
- VRAM: ~5 GB for Medium. The first run spends ~4 s on Triton autotuning; after that, generations are sub-second.
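If you want to know up front which attention path you'll end up on, the quickest check is whether the optional package is importable at all. Here's a minimal sketch, assuming the flash-attention build installs under the module name `flash_attn`. `pick_attention_backend` is a hypothetical helper; it does not reproduce SA3's actual dispatch code.

```python
import importlib.util


def pick_attention_backend(preferred: str = "flash_attn") -> str:
    """Report which attention path is *available*, not which one SA3 picks.

    Hypothetical helper: it only checks whether the optional package can be
    found; SA3's real dispatch logic may differ.
    """
    if importlib.util.find_spec(preferred) is not None:
        return preferred
    # Optional package missing: the Triton flex_attention path is what's left
    return "flex_attention"


if __name__ == "__main__":
    print("attention path:", pick_attention_backend())
```

Note that `find_spec` only tells you the package is installed, not that its kernels actually load on your card.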
Setup:
python3 -m venv sa3-venv && source sa3-venv/bin/activate
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2
git clone https://github.com/Stability-AI/stable-audio-3
cd stable-audio-3
sed -i "s/torch==2.7.1/torch>=2.7.1/g; s/torchaudio==2.7.1/torchaudio>=2.7.1/g" pyproject.toml
pip install -e .
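If you'd rather patch pyproject.toml from Python (or just want to see exactly what the sed line changes), here's a sketch of the same substitution. It assumes the pins appear literally as `torch==2.7.1` / `torchaudio==2.7.1`, which is what the sed command matches; `relax_torch_pins` and `patch_pyproject` are hypothetical helper names.

```python
import re
from pathlib import Path

# Same substitutions as the sed one-liner: "==2.7.1" becomes ">=2.7.1" for
# torch and torchaudio only. Dots are escaped so they match literally, and the
# word boundaries keep torchvision pins and versions like 2.7.10 untouched.
PIN_PATTERN = re.compile(r"\b(torch|torchaudio)==2\.7\.1\b")


def relax_torch_pins(text: str) -> str:
    """Rewrite exact torch/torchaudio 2.7.1 pins as lower bounds."""
    return PIN_PATTERN.sub(r"\1>=2.7.1", text)


def patch_pyproject(path: str = "pyproject.toml") -> None:
    """Patch the file in place, keeping a .bak copy like `sed -i.bak`."""
    p = Path(path)
    original = p.read_text()
    p.with_name(p.name + ".bak").write_text(original)
    p.write_text(relax_torch_pins(original))


if __name__ == "__main__":
    sample = 'dependencies = ["torch==2.7.1", "torchaudio==2.7.1", "numpy"]'
    print(relax_torch_pins(sample))
```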
Inference:
import os
os.environ["HIP_VISIBLE_DEVICES"] = "0"  # set before torch/SA3 touch the GPU
import soundfile as sf
from stable_audio_3 import StableAudioModel
model = StableAudioModel.from_pretrained("medium")
audio = model.generate("epic orchestral fanfare", duration=30)
sf.write("output.wav", audio.detach().cpu().squeeze(0).T.numpy(), 44100)
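Why the `.squeeze(0).T` before writing: soundfile expects multichannel data as (frames, channels), while generate() returns (1, channels, samples) according to the comment in the full script. Here's a dependency-free sketch of that reshape on nested lists; `channels_first_to_frames` is a hypothetical helper for illustration only.

```python
def channels_first_to_frames(audio):
    """Convert a (1, channels, samples) nested list to (samples, channels).

    Mirrors what audio.squeeze(0).T does to the tensor before sf.write:
    drop the batch axis, then transpose so each row is one frame.
    """
    if len(audio) != 1:
        raise ValueError("expected a batch dimension of size 1")
    channels = audio[0]  # squeeze(0): shape (channels, samples)
    return [list(frame) for frame in zip(*channels)]  # .T: (samples, channels)


if __name__ == "__main__":
    stereo = [[[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]]]  # 1 batch, 2 channels, 3 samples
    print(channels_first_to_frames(stereo))
```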
For ROCm 6.x, use the https://download.pytorch.org/whl/rocm6.2 index instead. This works on RDNA2 (6700 XT, etc.) and RDNA3 (7900 XTX).
Put together, the full setup looks something like this:
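Picking the index URL boils down to taking the major.minor part of your ROCm version. A minimal sketch, assuming the version string looks like what `hipconfig --version` prints (a dotted version followed by a build suffix); the helper names are hypothetical, and the 6.2 fallback mirrors the setup script. Not every X.Y is guaranteed to have a wheel index, so check the URL if pip finds nothing.

```python
import re

DEFAULT_ROCM = "6.2"  # same fallback the setup script uses when detection fails


def rocm_major_minor(version_text: str, default: str = DEFAULT_ROCM) -> str:
    """Extract 'X.Y' from strings like '7.2.12345-abcdef'; fall back to default."""
    match = re.search(r"(\d+)\.(\d+)", version_text or "")
    return f"{match.group(1)}.{match.group(2)}" if match else default


def pytorch_index_url(version_text: str) -> str:
    """Build the PyTorch ROCm wheel index URL for a detected version string."""
    return f"https://download.pytorch.org/whl/rocm{rocm_major_minor(version_text)}"


if __name__ == "__main__":
    print(pytorch_index_url("7.2.12345-abcdef"))  # illustrative version string
```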
sa3_rocm_setup.sh
#!/bin/bash
# ============================================================
# Stable Audio 3 Medium - ROCm (AMD GPU) Setup
# Tested: RX 7900 XTX, ROCm 7.2.3, PyTorch 2.11+rocm7.2
# Works on RDNA2 (gfx1030) and RDNA3 (gfx1100) cards
# ============================================================
set -e
echo "=== Stable Audio 3 ROCm Setup ==="
# 1. FIND YOUR GPU ARCH
echo "Checking GPU..."
# "|| true" keeps set -e from aborting the script if rocminfo is missing
GFX=$(rocminfo 2>/dev/null | grep "gfx" | head -1 | grep -oP 'gfx\w+' || true)
echo " GPU arch: ${GFX:-unknown}"
if command -v rocm-smi &>/dev/null; then
    rocm-smi --showproductname 2>/dev/null | grep "Card Series" || true
fi
# 2. CREATE VENV
VENV_DIR="${1:-./sa3-venv}"
echo "Creating venv at: $VENV_DIR"
python3 -m venv "$VENV_DIR"
source "$VENV_DIR/bin/activate"
pip install --upgrade pip -q
# 3. INSTALL PYTORCH FOR ROCM
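# Check which ROCm version is installed (X.Y); fall back to 6.2 if it can't be detected.
# If pip finds no torch wheel at the resulting index, set ROCM_VER by hand (e.g. 6.2).
ROCM_VER=$(hipconfig --version 2>/dev/null | grep -oE '[0-9]+\.[0-9]+' | head -1)
ROCM_VER="${ROCM_VER:-6.2}"
echo " ROCm version: $ROCM_VER"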
PYTORCH_INDEX="https://download.pytorch.org/whl/rocm${ROCM_VER}"
echo "Installing PyTorch from $PYTORCH_INDEX ..."
pip install torch torchvision torchaudio --index-url "$PYTORCH_INDEX"
# Verify
python3 -c "
import torch
print(f' PyTorch: {torch.__version__}')
print(f' GPU: {torch.cuda.get_device_name(0)}')
print(f' VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
"
# 4. CLONE AND PATCH SA3
if [ ! -d "stable-audio-3" ]; then
    echo "Cloning stable-audio-3..."
    git clone https://github.com/Stability-AI/stable-audio-3
fi
cd stable-audio-3
# Relax strict torch version pins (SA3 pins torch==2.7.1; the ROCm wheels are 2.11+)
sed -i.bak 's/torch==2.7.1/torch>=2.7.1/g; s/torchaudio==2.7.1/torchaudio>=2.7.1/g' pyproject.toml
echo " Patched pyproject.toml (relaxed torch pins)"
# 5. INSTALL SA3
echo "Installing stable-audio-3..."
pip install -e . -q
# 6. (OPTIONAL) FLASH ATTENTION
# Flash Attention is NOT required: SA3 has a graceful fallback.
# The ROCm fork may fail on newer ROCm versions. Skip if it fails.
echo "Skipping flash-attn (optional, SA3 works without it)"
echo " To try: pip install --no-build-isolation git+https://github.com/ROCm/flash-attention@howiejay/navi_support"
echo ""
echo "=== Setup Complete ==="
echo ""
echo "Next steps:"
echo " 1. Accept model terms: https://huggingface.co/stabilityai/stable-audio-3-medium"
echo " 2. Set your HF token: export HF_TOKEN='hf_...'"
echo " 3. Run inference (see sa3_inference.py)"
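Step 1 of the script only keeps the first gfx target that rocminfo reports. On a machine with both an iGPU and a discrete card you may want to see every target before deciding what to put in HIP_VISIBLE_DEVICES. A sketch that pulls each distinct gfx target out of rocminfo's text output; `gpu_archs` is a hypothetical helper, and it assumes the targets appear as plain `gfx...` tokens, as the script's own grep does.

```python
import re
import subprocess

GFX_RE = re.compile(r"\bgfx[0-9a-f]+\b")


def gpu_archs(rocminfo_output: str) -> list[str]:
    """Return the distinct gfx targets in rocminfo text, in first-seen order."""
    archs: list[str] = []
    for arch in GFX_RE.findall(rocminfo_output):
        if arch not in archs:
            archs.append(arch)
    return archs


if __name__ == "__main__":
    try:
        out = subprocess.run(["rocminfo"], capture_output=True, text=True, check=False).stdout
    except FileNotFoundError:
        out = ""  # rocminfo not installed or not on PATH
    print("GPU archs:", gpu_archs(out) or ["unknown"])
```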
Then the actual inference script:
sa3_inference.py
"""
Stable Audio 3 Medium - ROCm (AMD GPU) Inference Script
Tested: RX 7900 XTX, ROCm 7.2.3, PyTorch 2.11+rocm7.2
Works without flash-attn (graceful fallback to Triton kernels)
"""
import os
# === CONFIGURATION ===
# Pin to a specific GPU (0 = first, 1 = second, etc.).
# Set this before importing torch so HIP only ever sees the chosen device.
GPU_ID = "0"
os.environ["HIP_VISIBLE_DEVICES"] = GPU_ID
import torch
import soundfile as sf
from stable_audio_3 import StableAudioModel
# Model: "small-music", "small-sfx", or "medium"
MODEL = "medium"
# Generate this many seconds of audio (max 380 for medium, 120 for small)
DURATION = 30
# Your prompt: be descriptive about instruments, atmosphere, space, tempo
PROMPT = (
"dark ambient drone with deep sub-bass, "
"haunting ethereal female choir pads, "
"slow cinematic atmosphere, 60 BPM"
)
# Output file
OUTPUT = "sa3_output.wav"
# === INFERENCE ===
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM free: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")
print(f"Loading {MODEL} model...")
model = StableAudioModel.from_pretrained(MODEL)
print("Model loaded.")
print(f"Generating {DURATION}s audio...")
audio = model.generate(PROMPT, duration=DURATION)
# Convert to numpy and save
# audio shape: (1, channels, samples), e.g. (1, 2, 1323000) for 30 s of stereo
audio_np = audio.detach().cpu().squeeze(0).T.numpy()  # -> (samples, channels)
sf.write(OUTPUT, audio_np, 44100)
print(f"Saved: {OUTPUT}")
print(f"Duration: {len(audio_np)/44100:.1f}s, Size: {os.path.getsize(OUTPUT):,} bytes")
print("Done.")
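Before kicking off a long generation, a cheap sanity check on MODEL and DURATION can save a wasted run. A sketch, assuming the duration limits from the script's comment (380 s for medium, 120 s for small, taken here to apply to both small variants) and the 44100 Hz rate it writes with; `preflight` is a hypothetical helper. The byte estimate assumes soundfile's default 16-bit PCM subtype for WAV and excludes the header, so the size the script prints will be slightly larger.

```python
SAMPLE_RATE = 44100  # the rate sa3_inference.py passes to sf.write
# Limits from the script's comment; assumed to apply to both small variants
MAX_SECONDS = {"medium": 380, "small-music": 120, "small-sfx": 120}


def preflight(model: str, duration: float, channels: int = 2) -> dict:
    """Validate MODEL/DURATION and estimate output size before generating."""
    if model not in MAX_SECONDS:
        raise ValueError(f"unknown model {model!r}; expected one of {sorted(MAX_SECONDS)}")
    limit = MAX_SECONDS[model]
    if not 0 < duration <= limit:
        raise ValueError(f"duration must be in (0, {limit}] seconds for {model}")
    samples = round(duration * SAMPLE_RATE)
    return {
        # Matches the (1, channels, samples) layout noted in the script
        "tensor_shape": (1, channels, samples),
        # 16-bit PCM: 2 bytes per sample per channel, header not included
        "wav_data_bytes": samples * channels * 2,
    }


if __name__ == "__main__":
    print(preflight("medium", 30))
```

With the script's defaults (medium, 30 s) this gives a (1, 2, 1323000) tensor and about 5.3 MB of sample data.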
Good luck. Long live ROCm!