ROCm support scripts in the future?

#2
by choc-chilla - opened

Will it be a long wait? Poor man's AI user here <3

I got it working on ROCm. Here is the setup:

Hardware tested: RX 7900 XTX, ROCm 7.2.3, PyTorch 2.11+rocm7.2

Key findings:

  • Flash Attention is NOT required. SA3 falls back to Triton flex_attention kernels that work on ROCm out of the box.
  • Relax the strict torch==2.7.1 pin in pyproject.toml (ROCm wheels are 2.11+). Change to torch>=2.7.1.
  • VRAM: ~5 GB for Medium. The first run includes a 4 s Triton autotune pass; after that, generations take under a second.
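
If you'd rather not use sed, the pin change from the second bullet can be done in Python. This is only a sketch: it assumes the pins appear in pyproject.toml literally as torch==2.7.1 and torchaudio==2.7.1, the helper name relax_pins is made up, and the sample line (including the einops entry) is invented for illustration.

```python
import re

def relax_pins(pyproject_text: str) -> str:
    """Turn exact torch/torchaudio pins (==X) into minimum versions (>=X)."""
    # \b stops the pattern from matching names that merely end in "torch";
    # listing the package names explicitly leaves every other pin untouched.
    return re.sub(r"\b(torch|torchaudio)==", r"\1>=", pyproject_text)

# Made-up dependency line, not copied from the real pyproject.toml
sample = 'dependencies = ["torch==2.7.1", "torchaudio==2.7.1", "einops==0.8.0"]'
print(relax_pins(sample))
```

To apply it for real you would read pyproject.toml, pass the text through relax_pins, and write it back.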

Setup:

python3 -m venv sa3-venv && source sa3-venv/bin/activate
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2
git clone https://github.com/Stability-AI/stable-audio-3
cd stable-audio-3
sed -i "s/torch==2.7.1/torch>=2.7.1/g; s/torchaudio==2.7.1/torchaudio>=2.7.1/g" pyproject.toml
pip install -e .

Inference:

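import os
# Select the GPU before torch (pulled in by stable_audio_3) initializes HIP.
os.environ["HIP_VISIBLE_DEVICES"] = "0"

import soundfile as sf
from stable_audio_3 import StableAudioModel

model = StableAudioModel.from_pretrained("medium")
audio = model.generate("epic orchestral fanfare", duration=30)
# (1, channels, samples) -> (samples, channels), the layout soundfile expects
sf.write("output.wav", audio.detach().cpu().squeeze(0).T.numpy(), 44100)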

For ROCm 6.x, use https://download.pytorch.org/whl/rocm6.2. Works on RDNA2 (6700 XT, etc.) and RDNA3 (7900 XTX).
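
The index URL is just the base path plus rocm and the major.minor version, so it can be built from whatever version string your system reports. A rough sketch, with a made-up helper name (rocm_index_url) and the 6.2 fallback taken from the line above. Not every major.minor necessarily has a published index, so check that the URL exists before relying on it.

```python
import re

BASE = "https://download.pytorch.org/whl"

def rocm_index_url(version: str, default: str = "6.2") -> str:
    """Build the wheel index URL from a ROCm/HIP version string."""
    # Keep only major.minor, e.g. "7.2.3" -> "7.2"; use the default if nothing parses.
    m = re.search(r"(\d+)\.(\d+)", version)
    major_minor = f"{m.group(1)}.{m.group(2)}" if m else default
    return f"{BASE}/rocm{major_minor}"

print(rocm_index_url("7.2.3"))  # https://download.pytorch.org/whl/rocm7.2
print(rocm_index_url(""))       # https://download.pytorch.org/whl/rocm6.2
```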

Put together as a script, it looks something like this:

sa3_rocm_setup.sh

#!/bin/bash
# ============================================================
# Stable Audio 3 Medium - ROCm (AMD GPU) Setup
# Tested: RX 7900 XTX, ROCm 7.2.3, PyTorch 2.11+rocm7.2
# Works on RDNA2 (gfx1030) and RDNA3 (gfx1100) cards
# ============================================================
set -e

echo "=== Stable Audio 3 ROCm Setup ==="

# 1. FIND YOUR GPU ARCH
echo "Checking GPU..."
# head comes last so a missing rocminfo leaves GFX empty instead of tripping set -e
GFX=$(rocminfo 2>/dev/null | grep -oP 'gfx\w+' | head -1)
echo "  GPU arch: ${GFX:-unknown}"
if command -v rocm-smi &>/dev/null; then
    rocm-smi --showproductname 2>/dev/null | grep "Card Series" || true
fi

# 2. CREATE VENV
VENV_DIR="${1:-./sa3-venv}"
echo "Creating venv at: $VENV_DIR"
python3 -m venv "$VENV_DIR"
source "$VENV_DIR/bin/activate"
pip install --upgrade pip -q

# 3. INSTALL PYTORCH FOR ROCM
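# Detect the installed ROCm version (major.minor); fall back to 6.2 if hipconfig is missing.
# A pipeline returns the status of its last command, so "|| echo" after head never fires.
ROCM_VER=$(hipconfig --version 2>/dev/null | grep -oP '[0-9]+\.[0-9]+' | head -1)
ROCM_VER="${ROCM_VER:-6.2}"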
echo "  ROCm version: $ROCM_VER"
PYTORCH_INDEX="https://download.pytorch.org/whl/rocm${ROCM_VER}"

echo "Installing PyTorch from $PYTORCH_INDEX ..."
pip install torch torchvision torchaudio --index-url "$PYTORCH_INDEX"

# Verify
python3 -c "
import torch
print(f'  PyTorch: {torch.__version__}')
print(f'  GPU: {torch.cuda.get_device_name(0)}')
print(f'  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
"

# 4. CLONE AND PATCH SA3
if [ ! -d "stable-audio-3" ]; then
    echo "Cloning stable-audio-3..."
    git clone https://github.com/Stability-AI/stable-audio-3
fi
cd stable-audio-3

# Relax strict torch version pins (SA3 requires torch==2.7.1, ROCm has 2.11+)
sed -i.bak 's/torch==2.7.1/torch>=2.7.1/g; s/torchaudio==2.7.1/torchaudio>=2.7.1/g' pyproject.toml
echo "  Patched pyproject.toml (relaxed torch pins)"

# 5. INSTALL SA3
echo "Installing stable-audio-3..."
pip install -e . -q

# 6. (OPTIONAL) FLASH ATTENTION
# Flash Attention is NOT required; SA3 has graceful fallback.
# The ROCm fork may fail on newer ROCm versions. Skip if it fails.
echo "Skipping flash-attn (optional, SA3 works without it)"
echo "  To try: pip install --no-build-isolation git+https://github.com/ROCm/flash-attention@howiejay/navi_support"

echo ""
echo "=== Setup Complete ==="
echo ""
echo "Next steps:"
echo "  1. Accept model terms: https://huggingface.co/stabilityai/stable-audio-3-medium"
echo "  2. Set your HF token:  export HF_TOKEN='hf_...'"
echo "  3. Run inference (see sa3_inference.py)"
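
One thing the verify step in the script does not check is whether pip actually installed a ROCm build rather than a default CUDA or CPU wheel. Wheels from the ROCm index carry a local version tag after a plus sign, as in the 2.11+rocm7.2 string quoted above, so a small check on the version string can catch that. A sketch only; the helper names are hypothetical, and the cu126 string is just an example of a non-ROCm tag.

```python
def wheel_backend(torch_version: str) -> str:
    """Return the local build tag of a PyTorch version string, or '' if none."""
    # PEP 440 local version labels follow a '+', e.g. "2.11+rocm7.2" -> "rocm7.2".
    _, _, local = torch_version.partition("+")
    return local

def is_rocm_build(torch_version: str) -> bool:
    """True when the local build tag starts with 'rocm'."""
    return wheel_backend(torch_version).startswith("rocm")

print(is_rocm_build("2.11+rocm7.2"))  # True
print(is_rocm_build("2.7.1+cu126"))   # False
```

Inside the venv you would pass torch.__version__ to is_rocm_build. On ROCm builds torch.version.hip is also set instead of being None.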

Then the actual inference:

sa3_inference.py

"""
Stable Audio 3 Medium - ROCm (AMD GPU) Inference Script
Tested: RX 7900 XTX, ROCm 7.2.3, PyTorch 2.11+rocm7.2
Works without flash-attn (graceful fallback to Triton kernels)
"""
import os

# === CONFIGURATION ===
# Pin to a specific GPU (0 = first, 1 = second, etc.).
# Set this before importing torch so it is in place when the HIP runtime initializes.
GPU_ID = "0"
os.environ["HIP_VISIBLE_DEVICES"] = GPU_ID

import torch
import soundfile as sf
from stable_audio_3 import StableAudioModel

# Model: "small-music", "small-sfx", or "medium"
MODEL = "medium"

# Generate this many seconds of audio (max 380 for medium, 120 for small)
DURATION = 30

# Your prompt: be descriptive about instruments, atmosphere, space, tempo
PROMPT = (
    "dark ambient drone with deep sub-bass, "
    "haunting ethereal female choir pads, "
    "slow cinematic atmosphere, 60 BPM"
)

# Output file
OUTPUT = "sa3_output.wav"

# === INFERENCE ===
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM free: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")

print(f"Loading {MODEL} model...")
model = StableAudioModel.from_pretrained(MODEL)
print("Model loaded.")

print(f"Generating {DURATION}s audio...")
audio = model.generate(PROMPT, duration=DURATION)

# Convert to numpy and save
# audio shape: (1, channels, samples), e.g. (1, 2, 1323000) for 30s stereo
audio_np = audio.detach().cpu().squeeze(0).T.numpy()  # → (samples, channels)
sf.write(OUTPUT, audio_np, 44100)

print(f"Saved: {OUTPUT}")
print(f"Duration: {len(audio_np)/44100:.1f}s, Size: {os.path.getsize(OUTPUT):,} bytes")
print("Done.")
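
To sanity-check the duration and size the script prints, the expected numbers are easy to work out by hand. The sketch below assumes the file is written as 16-bit PCM, which I believe is what soundfile defaults to for WAV. The function names are made up, and the WAV header adds a few dozen bytes on top of the payload.

```python
SAMPLE_RATE = 44100  # the rate passed to sf.write in the script

def expected_samples(duration_s: float, sample_rate: int = SAMPLE_RATE) -> int:
    """Number of samples per channel for a clip of the given length."""
    return round(duration_s * sample_rate)

def pcm16_payload_bytes(duration_s: float, channels: int = 2) -> int:
    """Size of the raw audio data for 16-bit PCM (2 bytes per sample per channel)."""
    return expected_samples(duration_s) * channels * 2

print(expected_samples(30))     # 1323000
print(pcm16_payload_bytes(30))  # 5292000
```

If the reported size is far from this, the output is probably mono, cut short, or written with a different sample format.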

Good luck. Long live ROCm!
