BAT-vit-b16-pretrainedAS2M
BAT is a self-supervised audio foundation model introduced in the paper BAT: Better Audio Transformer Guided by Convex Gated Probing. This release contains the BAT ViT-B/16 encoder weights pretrained on AudioSet-2M.
The release contains encoder weights only. Downstream heads should be initialized separately for the task at hand.
Model Details
- Architecture: BAT ViT-B/16 encoder with gated self-attention
- Hidden size: 768
- Encoder depth: 12
- Attention heads: 12
- Input: log-mel spectrograms of shape [batch, 1, 1024, 128]
- Patch size: 16 x 16
- Pretraining data: AudioSet-2M
- Pretraining checkpoint step: 402610
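The token counts in the usage examples follow directly from the input shape and patch size. A short arithmetic sketch, assuming non-overlapping 16 x 16 patches plus one prepended CLS token (consistent with the `cls_token` entry in the state dict and the [2, 513, 768] and [2, 768, 64, 8] output shapes):

```python
# Arithmetic sketch: how a [batch, 1, 1024, 128] log-mel input becomes 513 tokens.
# Assumption: non-overlapping 16 x 16 patches plus a single prepended CLS token.
TIME_FRAMES = 1024
MEL_BINS = 128
PATCH = 16

time_patches = TIME_FRAMES // PATCH        # 1024 / 16 = 64 patches along time
mel_patches = MEL_BINS // PATCH            # 128 / 16 = 8 patches along frequency
num_patches = time_patches * mel_patches   # 64 * 8 = 512 patch tokens
num_tokens = num_patches + 1               # + 1 CLS token = 513

print(time_patches, mel_patches, num_patches, num_tokens)  # 64 8 512 513
```

The 64 x 8 patch grid is also what appears in the trailing dimensions of `patch_tokens`.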
Usage
Load the Hugging Face model directly
import torch
from transformers import AutoModel
model = AutoModel.from_pretrained(
"lrauch/BAT-vit-b16-pretrainedAS2M",
trust_remote_code=True,
).eval()
# Already-preprocessed BAT log-mel features:
# [batch, channel, time, mel]
features = torch.randn(2, 1, 1024, 128)
with torch.no_grad():
outputs = model(input_features=features)
print(outputs.last_hidden_state.shape) # [2, 513, 768]
print(outputs.pooler_output.shape) # [2, 768]
print(outputs.patch_tokens.shape) # [2, 768, 64, 8]
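Because the release ships encoder weights only, a task-specific head has to be created and trained on top of the encoder outputs. Below is a minimal linear-probe sketch on `pooler_output`, assuming a frozen encoder, a hypothetical `NUM_CLASSES`, and random stand-in embeddings so the snippet runs without downloading the model. This is plain linear probing, not the convex gated probing from the paper, and the optimizer and learning rate are illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 10  # hypothetical: set to the number of labels in your task

# Freshly initialized linear head on the 768-d pooled embedding.
head = nn.Linear(768, NUM_CLASSES)
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)  # illustrative settings

# Stand-in for outputs.pooler_output ([batch, 768]). With a frozen encoder you
# would compute it under torch.no_grad(), so only the head receives gradients.
pooled = torch.randn(2, 768)
targets = torch.tensor([3, 7])  # hypothetical class labels

# One training step for the head.
logits = head(pooled)                      # [2, NUM_CLASSES]
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(logits.shape)  # torch.Size([2, 10])
```

To keep the encoder frozen in a real setup, you can call `requires_grad_(False)` on its parameters and pass only `head.parameters()` to the optimizer, as done here.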
Use the bundled convenience loader
import sys
import torch
from huggingface_hub import snapshot_download
local_dir = snapshot_download("lrauch/BAT-vit-b16-pretrainedAS2M")
sys.path.insert(0, local_dir)
from load_model import load_pretrained_encoder
model = load_pretrained_encoder(device="cuda")
fbank = torch.randn(2, 1, 1024, 128, device="cuda")
with torch.no_grad():
features = model.forward_encoder(fbank)
print(features.shape) # [2, 513, 768]
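`forward_encoder` returns all 513 tokens per clip. A sketch of splitting them, assuming token 0 is the CLS token (the state dict includes a `cls_token` parameter and 513 = 512 patches + 1). Mean-pooling the patch tokens is shown as one common alternative clip embedding; neither option is claimed to match how this model computes `pooler_output`.

```python
import torch

# Stand-in for the output of model.forward_encoder(fbank): [batch, 513, 768].
features = torch.randn(2, 513, 768)

# Assumption: index 0 along the token axis is the CLS token.
cls_embedding = features[:, 0]       # [2, 768]
patch_tokens = features[:, 1:]       # [2, 512, 768]

# Alternative clip-level embedding: average over the 512 patch tokens.
mean_embedding = patch_tokens.mean(dim=1)  # [2, 768]

print(cls_embedding.shape, patch_tokens.shape, mean_embedding.shape)
```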
Use raw waveform preprocessing
To start from raw audio instead of precomputed features, use the bundled processor:
import sys
import torch
from huggingface_hub import snapshot_download
local_dir = snapshot_download("lrauch/BAT-vit-b16-pretrainedAS2M")
sys.path.insert(0, local_dir)
from load_model import load_audio_processor, load_pretrained_encoder
processor = load_audio_processor(device="cuda")
model = load_pretrained_encoder(device="cuda")
waveform = torch.randn(2, 16000 * 10, device="cuda")
input_features = processor(waveform)
with torch.no_grad():
outputs = model(input_features=input_features)
print(input_features.shape) # [2, 1, 1024, 128]
print(outputs.last_hidden_state.shape) # [2, 513, 768]
The processor follows the BAT training preprocessing: 16 kHz audio, a 128-bin mel spectrogram, power-to-dB compression, per-sample min-max normalization, padding or cropping to 1024 frames, and finally transposition to [batch, 1, time, mel].
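The post-spectrogram steps listed above can be sketched in plain PyTorch. This is a hypothetical re-implementation for illustration, not the bundled `processing_bat.py`: it starts from an already computed mel power spectrogram of shape [batch, mel, time], and the dB floor, zero padding, and keep-the-first-frames cropping are assumptions. For exact feature parity, use the bundled processor and the preprocessing settings in config.json.

```python
import torch
import torch.nn.functional as F

TARGET_FRAMES = 1024  # from the model card
EPS = 1e-10           # assumption: floor to avoid log10(0) and division by zero


def sketch_postprocess_mel(mel_power: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: [batch, mel, time] power -> [batch, 1, 1024, mel]."""
    # Power-to-dB compression.
    db = 10.0 * torch.log10(mel_power.clamp(min=EPS))

    # Per-sample min-max normalization to [0, 1].
    flat = db.flatten(start_dim=1)
    mins = flat.min(dim=1).values.view(-1, 1, 1)
    maxs = flat.max(dim=1).values.view(-1, 1, 1)
    norm = (db - mins) / (maxs - mins).clamp(min=EPS)

    # Pad on the right with zeros, or keep the first TARGET_FRAMES frames
    # (both choices are assumptions about the real processor).
    frames = norm.shape[-1]
    if frames < TARGET_FRAMES:
        norm = F.pad(norm, (0, TARGET_FRAMES - frames))
    else:
        norm = norm[..., :TARGET_FRAMES]

    # [batch, mel, time] -> [batch, time, mel] -> [batch, 1, time, mel]
    return norm.transpose(1, 2).unsqueeze(1)


short = sketch_postprocess_mel(torch.rand(2, 128, 900))   # padded
long = sketch_postprocess_mel(torch.rand(2, 128, 1200))   # cropped
print(short.shape, long.shape)  # torch.Size([2, 1, 1024, 128]) torch.Size([2, 1, 1024, 128])
```

Normalizing before padding means the padded region sits at 0, the per-clip minimum; if the real processor pads differently, its outputs will differ from this sketch.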
Download only the checkpoint weights
If you want to integrate the pretrained encoder weights into your own codebase without using transformers, download the raw model.safetensors file:
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
weights_path = hf_hub_download(
repo_id="lrauch/BAT-vit-b16-pretrainedAS2M",
filename="model.safetensors",
)
state_dict = load_file(weights_path, device="cpu")
print(state_dict.keys())
The state dict uses encoder module names directly, for example:
cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
pre_norm.weight
pre_norm.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.gate.weight
...
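When wiring these weights into your own encoder, a quick sanity check is to group the keys by top-level module and count the transformer blocks. A small sketch with a hypothetical helper `summarize_state_dict_keys`; on the full checkpoint (`summarize_state_dict_keys(state_dict.keys())`) the inferred depth should be 12 if it matches the Model Details.

```python
import re
from collections import Counter


def summarize_state_dict_keys(keys):
    """Hypothetical helper: count keys per top-level module and infer block count."""
    keys = list(keys)
    top_level = Counter(key.split(".")[0] for key in keys)
    block_ids = {
        int(m.group(1))
        for key in keys
        if (m := re.match(r"blocks\.(\d+)\.", key))
    }
    depth = max(block_ids) + 1 if block_ids else 0
    return top_level, depth


# Keys taken from the example listing above (a partial view of the checkpoint).
example_keys = [
    "cls_token", "pos_embed",
    "patch_embed.proj.weight", "patch_embed.proj.bias",
    "pre_norm.weight", "pre_norm.bias",
    "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias",
    "blocks.0.attn.proj.weight", "blocks.0.attn.gate.weight",
]
top_level, depth = summarize_state_dict_keys(example_keys)
print(dict(top_level))  # {'cls_token': 1, 'pos_embed': 1, 'patch_embed': 2, 'pre_norm': 2, 'blocks': 4}
print(depth)            # 1
```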
Files
- model.safetensors: pretrained BAT encoder weights
- config.json: architecture and audio preprocessing configuration
- configuration_bat.py: custom Transformers config
- modeling_bat.py: vendored BAT encoder architecture
- processing_bat.py: optional waveform-to-feature processor
- load_model.py: convenience loader
Required dependencies: torch, transformers, safetensors, huggingface_hub. The raw waveform processor also requires torchaudio.
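A quick way to check these dependencies before loading the model. This is a generic sketch using `importlib.util.find_spec`, which only reports whether a package is importable, not whether its version is compatible.

```python
import importlib.util

REQUIRED = ["torch", "transformers", "safetensors", "huggingface_hub"]
OPTIONAL = ["torchaudio"]  # only needed for the raw waveform processor


def missing_packages(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print("missing required:", missing_packages(REQUIRED))
print("missing optional:", missing_packages(OPTIONAL))
```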
Citation
@inproceedings{ghaffari2026batbetteraudiotransformer,
title={BAT: Better Audio Transformer Guided by Convex Gated Probing},
author={Houtan Ghaffari and Lukas Rauch and Christoph Scholz and Paul Devos},
year={2026},
booktitle={International Conference on Machine Learning (ICML)}
}