BAT-vit-b16-pretrainedAS2M

BAT is a self-supervised audio foundation model introduced in the paper BAT: Better Audio Transformer Guided by Convex Gated Probing. This release contains the BAT ViT-B/16 encoder weights pretrained on AudioSet-2M.

The release contains encoder weights only; no classification head is included. Initialize a downstream head separately for your task.
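For example, a conventional linear probe on top of a 768-dim clip embedding can be sketched as below. This is a plain head, not the paper's convex gated probing; `LinearProbeHead`, the LayerNorm, and the 527-class output (the size of the AudioSet label set) are illustrative assumptions, and a random tensor stands in for the encoder's pooled output.

```python
import torch
from torch import nn


class LinearProbeHead(nn.Module):
    """Hypothetical task head: LayerNorm + Linear over a 768-dim clip embedding."""

    def __init__(self, hidden_size: int = 768, num_classes: int = 527):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)
        # Fresh initialization, independent of the pretrained encoder weights.
        nn.init.trunc_normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # pooled: [batch, hidden_size] -> logits: [batch, num_classes]
        return self.classifier(self.norm(pooled))


head = LinearProbeHead(hidden_size=768, num_classes=527)
pooled = torch.randn(2, 768)  # stand-in for the encoder's pooled output
logits = head(pooled)
print(logits.shape)  # torch.Size([2, 527])
```

To train only the head, freeze the encoder with `for p in model.parameters(): p.requires_grad_(False)` and pass `head.parameters()` to the optimizer.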

Model Details

  • Architecture: BAT ViT-B/16 encoder with gated self-attention
  • Hidden size: 768
  • Encoder depth: 12
  • Attention heads: 12
  • Input: log-mel spectrograms of shape [batch, 1, 1024, 128] (1024 time frames x 128 mel bins)
  • Patch size: 16 x 16
  • Pretraining data: AudioSet-2M
  • Pretraining checkpoint step: 402610
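The sequence lengths printed in the usage examples follow from these numbers. A minimal sketch of the arithmetic, assuming non-overlapping 16 x 16 patches (stride equal to the patch size) and a single prepended CLS token; `bat_token_layout` is a hypothetical helper. With hidden size 768 this gives encoder outputs of shape [batch, 513, 768].

```python
def bat_token_layout(time_frames: int = 1024, mel_bins: int = 128,
                     patch_size: int = 16, num_cls_tokens: int = 1):
    """Return (time_patches, mel_patches, sequence_length) for non-overlapping patches."""
    if time_frames % patch_size or mel_bins % patch_size:
        raise ValueError("input size must be divisible by the patch size")
    time_patches = time_frames // patch_size   # 1024 / 16 = 64
    mel_patches = mel_bins // patch_size       # 128 / 16 = 8
    seq_len = time_patches * mel_patches + num_cls_tokens  # 512 patches + 1 CLS
    return time_patches, mel_patches, seq_len


print(bat_token_layout())  # (64, 8, 513)
```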

Usage

Load the Hugging Face model directly

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "lrauch/BAT-vit-b16-pretrainedAS2M",
    trust_remote_code=True,
).eval()

# Already-preprocessed BAT log-mel features:
# [batch, channel, time, mel]
features = torch.randn(2, 1, 1024, 128)

with torch.no_grad():
    outputs = model(input_features=features)

print(outputs.last_hidden_state.shape)  # [2, 513, 768]
print(outputs.pooler_output.shape)      # [2, 768]
print(outputs.patch_tokens.shape)       # [2, 768, 64, 8]
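The shape of `patch_tokens` suggests it is the patch part of the sequence laid out on the 64 x 8 time-frequency grid. A sketch of that reshape, assuming the CLS token comes first and the patches are ordered time-major (the usual Conv2d-plus-flatten patch embedding); `tokens_to_grid` is hypothetical, and the card does not confirm that it reproduces `outputs.patch_tokens` exactly.

```python
import torch


def tokens_to_grid(last_hidden_state: torch.Tensor, time_patches: int = 64,
                   mel_patches: int = 8, num_prefix_tokens: int = 1) -> torch.Tensor:
    """Hypothetical helper: [batch, 1 + T*F, dim] -> [batch, dim, T, F]."""
    patches = last_hidden_state[:, num_prefix_tokens:, :]  # drop the (assumed) CLS token
    batch, num_patches, dim = patches.shape
    if num_patches != time_patches * mel_patches:
        raise ValueError(f"expected {time_patches * mel_patches} patches, got {num_patches}")
    # [batch, T*F, dim] -> [batch, dim, T*F] -> [batch, dim, T, F] (time-major order assumed)
    return patches.transpose(1, 2).reshape(batch, dim, time_patches, mel_patches)


hidden = torch.randn(2, 513, 768)  # stand-in for outputs.last_hidden_state
grid = tokens_to_grid(hidden)
print(grid.shape)  # torch.Size([2, 768, 64, 8])
```

With the real model, `torch.allclose(tokens_to_grid(outputs.last_hidden_state), outputs.patch_tokens)` is a quick check of whether the assumed ordering holds.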

Use the bundled convenience loader

import sys
import torch
from huggingface_hub import snapshot_download

local_dir = snapshot_download("lrauch/BAT-vit-b16-pretrainedAS2M")
sys.path.insert(0, local_dir)

from load_model import load_pretrained_encoder

model = load_pretrained_encoder(device="cuda")

fbank = torch.randn(2, 1, 1024, 128, device="cuda")

with torch.no_grad():
    features = model.forward_encoder(fbank)

print(features.shape)  # [2, 513, 768]
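`forward_encoder` returns the full token sequence. To get one vector per clip, two common choices are the CLS token or the mean of the patch tokens. A sketch assuming the CLS token sits at index 0 (consistent with 513 = 1 + 512 tokens, though the card does not state its position); `clip_embeddings` is hypothetical, and the card does not say which of the two `pooler_output` corresponds to.

```python
import torch


def clip_embeddings(features: torch.Tensor, num_prefix_tokens: int = 1):
    """Hypothetical helper: [batch, 1 + N, dim] -> (cls, mean_patch), each [batch, dim]."""
    cls_embedding = features[:, 0, :]  # assumed CLS position
    mean_patch_embedding = features[:, num_prefix_tokens:, :].mean(dim=1)
    return cls_embedding, mean_patch_embedding


features = torch.randn(2, 513, 768)  # stand-in for model.forward_encoder(fbank)
cls_emb, mean_emb = clip_embeddings(features)
print(cls_emb.shape, mean_emb.shape)  # torch.Size([2, 768]) torch.Size([2, 768])
```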

Use raw waveform preprocessing

To start from raw waveforms instead of precomputed features, use the bundled processor:

import sys
import torch
from huggingface_hub import snapshot_download

local_dir = snapshot_download("lrauch/BAT-vit-b16-pretrainedAS2M")
sys.path.insert(0, local_dir)

from load_model import load_audio_processor, load_pretrained_encoder

processor = load_audio_processor(device="cuda")
model = load_pretrained_encoder(device="cuda")

waveform = torch.randn(2, 16000 * 10, device="cuda")

input_features = processor(waveform)

with torch.no_grad():
    outputs = model(input_features=input_features)

print(input_features.shape)             # [2, 1, 1024, 128]
print(outputs.last_hidden_state.shape)  # [2, 513, 768]

The processor follows the BAT training preprocessing: 16 kHz audio, mel spectrogram, power-to-dB compression, per-sample min-max normalization, padding or cropping to 1024 frames, then transposition to [batch, 1, time, mel].

Download only the checkpoint weights

If you want to integrate the pretrained encoder weights into your own codebase without using transformers, download the raw model.safetensors file:

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

weights_path = hf_hub_download(
    repo_id="lrauch/BAT-vit-b16-pretrainedAS2M",
    filename="model.safetensors",
)

state_dict = load_file(weights_path, device="cpu")
print(state_dict.keys())
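Before wiring the weights into your own model, it can help to sanity-check the dict. A sketch with two hypothetical helpers: `summarize_state_dict` counts parameters and infers the encoder depth from keys of the form `blocks.N.`, and `add_prefix` prepends a module path for the case where your model stores the encoder under an attribute such as `self.encoder`. The demo dict below is synthetic and its key names beyond those listed in this card are illustrative.

```python
import re

import torch


def summarize_state_dict(state_dict: dict) -> tuple:
    """Return (total_parameter_count, encoder_depth) for a BAT-style state dict."""
    total = sum(tensor.numel() for tensor in state_dict.values())
    block_ids = set()
    for key in state_dict:
        match = re.match(r"blocks\.(\d+)\.", key)  # keys like "blocks.0.attn.qkv.weight"
        if match:
            block_ids.add(int(match.group(1)))
    depth = max(block_ids) + 1 if block_ids else 0
    return total, depth


def add_prefix(state_dict: dict, prefix: str) -> dict:
    """Prepend a module prefix, e.g. "encoder." if your model wraps the encoder."""
    return {prefix + key: value for key, value in state_dict.items()}


# Synthetic example; replace with the dict returned by load_file(...).
demo = {
    "cls_token": torch.zeros(1, 1, 768),
    "blocks.0.attn.qkv.weight": torch.zeros(2304, 768),
    "blocks.11.norm1.weight": torch.zeros(768),
}
print(summarize_state_dict(demo))  # (1771008, 12)
```

On the real checkpoint, the depth should come out as 12 and the total should be close to the 92.7M parameters reported for this model. For your own module, `my_model.load_state_dict(add_prefix(state_dict, "encoder."), strict=False)` returns the missing and unexpected keys, which quickly exposes naming mismatches (`my_model` is hypothetical).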

The state dict uses encoder module names directly, for example:

cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
pre_norm.weight
pre_norm.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.gate.weight
...
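If your codebase keeps separate query, key, and value projections, the fused `qkv` tensors need to be split. A sketch assuming the common timm layout, where `attn.qkv.weight` has shape [3 x 768, 768] with rows ordered query, key, value; the card does not confirm this layout, so verify the shape on the real checkpoint. `split_qkv` is hypothetical, and `attn.gate.weight` is left untouched because the card does not describe how the gate enters the attention computation.

```python
from typing import Optional

import torch


def split_qkv(qkv_weight: torch.Tensor, qkv_bias: Optional[torch.Tensor] = None,
              hidden_size: int = 768):
    """Split a fused [3*hidden, hidden] projection into q, k, v (rows assumed ordered q, k, v)."""
    if tuple(qkv_weight.shape) != (3 * hidden_size, hidden_size):
        raise ValueError(f"unexpected qkv weight shape {tuple(qkv_weight.shape)}")
    weights = qkv_weight.chunk(3, dim=0)  # three [hidden, hidden] blocks
    biases = qkv_bias.chunk(3, dim=0) if qkv_bias is not None else (None, None, None)
    return weights, biases


w = torch.randn(3 * 768, 768)  # stand-in for state_dict["blocks.0.attn.qkv.weight"]
b = torch.randn(3 * 768)       # stand-in for state_dict["blocks.0.attn.qkv.bias"]
(q_w, k_w, v_w), (q_b, k_b, v_b) = split_qkv(w, b)
print(q_w.shape, k_b.shape)  # torch.Size([768, 768]) torch.Size([768])
```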

Files

  • model.safetensors: pretrained BAT encoder weights
  • config.json: architecture and audio preprocessing configuration
  • configuration_bat.py: custom Transformers config
  • modeling_bat.py: vendored BAT encoder architecture
  • processing_bat.py: optional waveform-to-feature processor
  • load_model.py: convenience loader

Required dependencies: torch, transformers, safetensors, huggingface_hub. The raw waveform processor also requires torchaudio.

Citation

@inproceedings{ghaffari2026batbetteraudiotransformer,
      title={BAT: Better Audio Transformer Guided by Convex Gated Probing},
      author={Houtan Ghaffari and Lukas Rauch and Christoph Scholz and Paul Devos},
      year={2026},
      booktitle={International Conference on Machine Learning (ICML)}
}
Model size: 92.7M params (Safetensors, tensor type F32)