Paper: *Community Forensics: Using Thousands of Generators to Train Fake Image Detectors* (arXiv:2411.04125)
# Multi-Branch Frequency-Aware Detector: SwinV2 + SRM + DCT + FFT

A robust AI-generated image detector that combines semantic understanding with frequency-domain forensic analysis to detect AI-generated images from any source, including high-quality outputs from Stable Diffusion, DALL-E, Midjourney, Flux, and 4,800+ other generators.
This model uses a 4-branch fusion architecture for robust detection:

```
            Input Image (256×256)
                     │
      ┌──────────────┼──────────────┬──────────────┐
      │              │              │              │
      ▼              ▼              ▼              ▼
   SwinV2         SRM HPF     DCT Analyzer   FFT Analyzer
   (768d)         (256d)         (22d)          (36d)
      │              │              │              │
      │              └──────────────┼──────────────┘
      │                             ▼
      │                   Freq Features (314d)
      │                             │
      │                             ▼
      │                  Freq Projection (128d)
      │                             │
      └──────────────┬──────────────┘
                     ▼
       Fusion MLP (896d → 512 → 128 → 2)
                     │
                     ▼
            Real / AI-Generated
```
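The dimensions in the diagram can be checked with a minimal PyTorch sketch of the fusion head (the class and layer names below are illustrative, not the exact ones in train.py): the concatenated SRM/DCT/FFT features (256 + 22 + 36 = 314d) are projected to 128d, joined with the 768d SwinV2 embedding, and classified by the 896 → 512 → 128 → 2 MLP.

```python
import torch
import torch.nn as nn

class FusionHead(nn.Module):
    """Illustrative fusion head combining semantic and frequency features."""
    def __init__(self, sem_dim=768, freq_dim=314, proj_dim=128, num_classes=2):
        super().__init__()
        # Project concatenated SRM/DCT/FFT features (314d) down to 128d
        self.freq_proj = nn.Sequential(nn.Linear(freq_dim, proj_dim), nn.GELU())
        # Fusion MLP: 768 + 128 = 896 -> 512 -> 128 -> 2
        self.mlp = nn.Sequential(
            nn.Linear(sem_dim + proj_dim, 512), nn.GELU(),
            nn.Linear(512, 128), nn.GELU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, sem_feats, freq_feats):
        fused = torch.cat([sem_feats, self.freq_proj(freq_feats)], dim=1)
        return self.mlp(fused)

head = FusionHead()
logits = head(torch.randn(4, 768), torch.randn(4, 314))
print(logits.shape)  # torch.Size([4, 2])
```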
Backbone: `microsoft/swinv2-tiny-patch4-window8-256`

Total parameters: ~28.6M (compact enough for real-time inference)
Training dataset: `OwensLab/CommunityForensics-Small` (CVPR 2025)
During training, images are augmented on the fly; see the augmentation pipeline in train.py for the full list of transforms.
```bash
pip install transformers torch torchvision datasets evaluate accelerate trackio pillow scikit-learn
```
```bash
# Full training on GPU (recommended: A10G 24GB or better)
python train.py \
    --num_train_epochs 5 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-5 \
    --hub_model_id your-username/ai-image-detector

# Quick test run
python train.py --test_mode

# Custom settings
python train.py \
    --max_train_samples 50000 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --image_size 256
```
| Parameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning rate | 2e-5 |
| Weight decay | 0.01 |
| Warmup ratio | 0.1 |
| Batch size | 16 × 4 (gradient accumulation) = 64 effective |
| Epochs | 5 |
| Precision | bf16 |
| Label smoothing | 0.1 |
| Gradient checkpointing | ✓ |
| Image size | 256×256 |
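For reference, the warmup behavior in the table can be sketched as a plain function: linear warmup to the peak learning rate over the first 10% of steps. The linear decay afterwards is an assumption for illustration; the actual schedule is whatever train.py configures.

```python
def lr_at_step(step, total_steps, base_lr=2e-5, warmup_ratio=0.1):
    """Linear warmup (first warmup_ratio of steps), then assumed linear decay."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay from base_lr back to 0 over the remaining steps
    remaining = max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / remaining)

# Peak LR is reached exactly at the end of warmup
print(lr_at_step(100, 1000))  # 2e-05
```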
```python
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image

from train import FrequencyAwareDetector

# Load model
model = FrequencyAwareDetector()
state_dict = torch.load("model_state_dict.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Preprocess
transform = Compose([
    Resize((288, 288)),
    CenterCrop((256, 256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.open("test.jpg").convert("RGB")
pixel_values = transform(img).unsqueeze(0)

# Predict
with torch.no_grad():
    output = model(pixel_values=pixel_values)
    probs = torch.softmax(output["logits"], dim=1)
    pred = probs.argmax(dim=1).item()

labels = {0: "Real", 1: "AI-Generated"}
print(f"Prediction: {labels[pred]} ({probs[0][pred]:.2%} confidence)")
```
```bash
# Single image
python inference.py --image photo.jpg

# URL
python inference.py --image https://example.com/image.png

# Batch (entire directory)
python inference.py --image_dir ./photos/
```
AI-generated images often contain subtle artifacts that are invisible to the human eye but detectable in the frequency domain.
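One way to visualize such artifacts (a standalone NumPy sketch, not the model's actual DCT/FFT branches): inspect the centered log-magnitude FFT spectrum, where the upsampling layers of many generators leave periodic peaks or unnaturally regular high-frequency energy. The `high_freq_ratio` cue below is a hypothetical summary statistic for illustration.

```python
import numpy as np

def log_magnitude_spectrum(img):
    """img: 2D float array (grayscale). Returns centered log-magnitude FFT."""
    spectrum = np.fft.fftshift(np.fft.fft2(img))
    return np.log1p(np.abs(spectrum))

def high_freq_ratio(spec, frac=0.25):
    """Fraction of spectral energy outside a central low-frequency window."""
    h, w = spec.shape
    ch, cw = h // 2, w // 2
    dh, dw = int(h * frac), int(w * frac)
    low = spec[ch - dh:ch + dh, cw - dw:cw + dw].sum()
    return 1.0 - low / spec.sum()

# Example: load a grayscale image with PIL and inspect its spectrum, e.g.
#   img = np.asarray(Image.open("test.jpg").convert("L"), dtype=np.float32)
spec = log_magnitude_spectrum(np.random.rand(64, 64))
print(high_freq_ratio(spec))
```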
```
├── train.py              # Full training script with model architecture
├── inference.py          # Easy-to-use inference script
├── detector_config.json  # Model configuration
├── model_state_dict.pt   # Trained weights (after training)
└── README.md             # This file
```
License: Apache 2.0