AgriFM × PASTIS

Reimplementation of AgriFM: A Multi-source Temporal Remote Sensing Foundation Model for Agriculture Mapping (Li et al., 2025), adapted for the PASTIS benchmark dataset and trained from scratch on a single AMD MI300X GPU (ROCm 7.0).


🏆 Results: 5-Fold Cross-Validation

Summary

Metric Fold 1 Fold 2 Fold 3 Fold 4 Fold 5 Mean ± Std
mFscore 58.37% 56.69% 59.18% 61.00% 56.49% 58.35% ± 1.70%
mIoU 44.35% 42.90% 45.54% 46.80% 42.58% 44.43% ± 1.62%
OA 67.10% 66.85% 69.96% 71.48% 66.41% 68.36% ± 2.13%
Kappa 60.02% 59.71% 62.79% 64.33% 58.99% 61.17% ± 2.19%
mPrecision 52.78% 51.33% 54.51% 57.40% 51.24% 53.45% ± 2.47%
mRecall 70.72% 69.00% 68.69% 68.93% 68.72% 69.21% ± 0.82%

Benchmark result on the official PASTIS folds: mFscore = 58.35% ± 1.70%, trained from scratch with no pretrained weights, in about 32 minutes per fold on an AMD MI300X.


📊 Per-Class IoU: All Folds

Class Fold 1 Fold 2 Fold 3 Fold 4 Fold 5 Mean
🌽 Corn 76.76% 75.31% 75.75% 77.09% 74.60% 75.90%
🌿 Winter rapeseed 74.17% 76.05% 77.14% 77.49% 72.90% 75.55%
🌾 Beet 73.32% 72.52% 78.01% 71.87% 73.78% 73.90%
🌾 Soft winter wheat 72.22% 68.92% 71.64% 71.18% 73.37% 71.47%
🌿 Soybeans 66.81% 55.03% 63.35% 62.73% 59.49% 61.48%
🌾 Winter barley 56.56% 59.08% 57.70% 62.90% 56.94% 58.64%
🌱 Meadow 56.56% 55.96% 57.47% 58.18% 55.29% 56.69%
🌻 Sunflower 53.94% 48.13% 64.54% 57.16% 50.59% 54.87%
🟫 Background 48.39% 50.20% 51.60% 54.88% 49.83% 50.98%
🌾 Winter durum wheat 43.30% 43.68% 49.19% 47.26% 39.38% 44.56%
🥔 Potatoes 37.33% 37.75% 34.93% 42.60% 29.07% 36.34%
🌿 Grapevine 36.16% 36.62% 37.52% 38.60% 34.60% 36.70%
🌾 Spring barley 30.51% 40.00% 32.87% 38.06% 33.97% 35.08%
🌿 Leguminous fodder 23.17% 20.88% 21.02% 28.88% 18.12% 22.41%
🍎 Fruits/veg/flowers 22.27% 19.75% 20.97% 22.67% 21.97% 21.53%
🌾 Winter triticale 25.01% 19.33% 24.30% 23.72% 19.32% 22.34%
🌾 Mixed cereal 17.63% 12.64% 13.20% 16.81% 13.02% 14.66%
🍑 Orchard 15.25% 13.11% 26.05% 18.81% 13.14% 17.27%
🌿 Sorghum 13.29% 10.12% 8.04% 18.40% 19.68% 13.91%

⚑ Inference Speed

Measured on AMD MI300X (191.7 GB VRAM), AgriFM-small (39.6M params):

Batch Size Time (ms) Throughput VRAM Used
1 7.8 ms 128.9 patches/sec 1.12 GB
4 19.6 ms 203.8 patches/sec 2.08 GB
8 37.1 ms 215.4 patches/sec 3.40 GB
16 73.4 ms 218.0 patches/sec 6.03 GB
32 143.0 ms 223.8 patches/sec 11.31 GB
64 281.0 ms 227.8 patches/sec 21.85 GB

Each patch = 32 temporal frames × 10 spectral bands × 128×128 pixels of Sentinel-2 imagery.
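The throughput column is the batch size divided by the per-batch latency. A minimal sketch that recomputes it from the rounded latencies in the table (the helper name `throughput` is illustrative, not part of the repository); small differences from the reported values, such as 128.2 versus 128.9 at batch size 1, are expected because the latencies shown are rounded to 0.1 ms:

```python
# (batch size, latency in ms) pairs copied from the table above
MEASUREMENTS = [(1, 7.8), (4, 19.6), (8, 37.1), (16, 73.4), (32, 143.0), (64, 281.0)]


def throughput(batch_size: int, latency_ms: float) -> float:
    """Patches processed per second for one forward pass."""
    return batch_size / (latency_ms / 1000.0)


for batch_size, latency_ms in MEASUREMENTS:
    rate = throughput(batch_size, latency_ms)
    print(f"batch {batch_size:>2}: {rate:6.1f} patches/sec")
```

Throughput flattens beyond batch size 8, so larger batches mostly cost VRAM without a matching speed gain.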


🏗️ Model Architecture

Input: (B, 32, 10, 128, 128)  ← 32 frames, 10 S2 bands, 128×128 pixels
    ↓
SwinPatchEmbed3D  patch_size=(4,2,2)  → (B, 64, 8, 64, 64)
    ↓
SwinTransformer3D: 4 stages with synchronized spatiotemporal downsampling
  Stage 1: depth=2, heads=2,  dim=64   → (B, 64,  8, 64, 64)
  Stage 2: depth=2, heads=4,  dim=128  → (B, 128, 4, 32, 32)
  Stage 3: depth=6, heads=8,  dim=256  → (B, 256, 2, 16, 16)
  Stage 4: depth=2, heads=16, dim=512  → (B, 512, 1,  8,  8)
    ↓
feature_fusion='cat' → (B, 512, 8, 8)  ← temporal dim collapsed into channels
    ↓
MultiFusionNeck: 3-level U-Net decoder with skip connections
  Level 1: 8×8   → 16×16  (skip from Stage 3)
  Level 2: 16×16 → 32×32  (skip from Stage 2)
  Level 3: 32×32 → 64×64  (skip from Stage 1)
  Final:   64×64 → 128×128
    ↓
CropFCNHead  Conv(512→256)+ReLU+Conv(256→20)
    ↓
Output: (B, 20, 128, 128)  ← per-pixel crop type logits
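The stage shapes in the diagram follow from simple integer arithmetic. A small sketch that reproduces them, assuming every merge between stages doubles the channels and halves the temporal, height, and width axes (which is what the 8 → 4 → 2 → 1 progression implies); `trace_shapes` is a hypothetical helper, not a function in the repository:

```python
def trace_shapes(frames=32, size=128, patch=(4, 2, 2), embed_dim=64, num_stages=4):
    """Return (channels, depth, height, width) at the output of each stage."""
    # Patch embedding divides each axis by its patch size.
    depth, height, width = frames // patch[0], size // patch[1], size // patch[2]
    dim = embed_dim
    shapes = []
    for stage in range(num_stages):
        if stage > 0:
            # Assumed merge step: channels x2, every spatiotemporal axis /2.
            dim *= 2
            depth, height, width = depth // 2, height // 2, width // 2
        shapes.append((dim, depth, height, width))
    return shapes


for stage, shape in enumerate(trace_shapes(), start=1):
    print(f"Stage {stage}: {shape}")
```

With the defaults, the last stage has a temporal depth of 1, which is why concatenating the temporal axis into the channels leaves 512 channels.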

Key Innovation: Synchronized Spatiotemporal PatchMerging

Unlike the standard Video Swin Transformer, which downsamples only spatially, AgriFM's PatchMerging with mean_frame_down=True downsamples both axes in the same step:

  • Spatially: 4 strided samples are concatenated → 4C channels
  • Temporally: mean pooling over D_step consecutive frames

This keeps spatial and temporal scales synchronized throughout the hierarchy.
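The two operations can be sketched in NumPy to make the shapes concrete. This is an illustration under assumptions, not the repository's module: the channels-last layout, the function name `sync_patch_merge`, the omission of the usual LayerNorm, and the plain 4C → 2C projection matrix are all choices made for this example.

```python
import numpy as np


def sync_patch_merge(x, weight, d_step=2):
    """x: (B, D, H, W, C) -> (B, D // d_step, H // 2, W // 2, weight.shape[1])."""
    b, d, h, w, c = x.shape
    assert h % 2 == 0 and w % 2 == 0 and d % d_step == 0
    # Spatial: gather the 4 strided neighbours and stack them on the channel axis.
    merged = np.concatenate(
        [x[:, :, 0::2, 0::2], x[:, :, 1::2, 0::2],
         x[:, :, 0::2, 1::2], x[:, :, 1::2, 1::2]],
        axis=-1,
    )  # (B, D, H/2, W/2, 4C)
    # Temporal: average each group of d_step consecutive frames.
    merged = merged.reshape(b, d // d_step, d_step, h // 2, w // 2, 4 * c).mean(axis=2)
    # Channel reduction (assumed here to be a bias-free linear map).
    return merged @ weight


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 8, 16, 16, 64))
weight = rng.standard_normal((4 * 64, 2 * 64))
print(sync_patch_merge(x, weight).shape)  # (1, 4, 8, 8, 128)
```

Both the frame count and the spatial resolution halve in one call, which is the synchronization the section describes.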


📁 Repository Structure

AgriFM_PASTIS/
├── models/
│   ├── video_swin_transformer.py   # VST backbone + synchronized PatchMerging
│   ├── neck.py                     # MultiFusionNeck (U-Net decoder)
│   ├── heads.py                    # CropFCNHead
│   └── agrifm.py                   # Full model + build_agrifm_pastis_small/tiny/base
├── datasets/
│   └── pastis_dataset.py           # Official 5-fold splits via geopandas
├── losses/
│   └── loss.py                     # CropCELoss + hard mining + class weights
├── evaluation/
│   └── metrics.py                  # OA, mIoU, mFscore, Precision, Recall, Kappa
├── train.py                        # Training loop with colored output + logging
├── visualize_results.py            # 8-plot visualization suite per fold
├── visualize_all_folds.py          # Cross-validation summary plots
├── requirements.txt
├── checkpoints/
│   └── fold1_v3_best_model.pth     # Best checkpoint: fold1, epoch 92, mFscore=58.37%
└── results/
    ├── fold1/                      # Logs, metrics, plots for fold 1
    ├── fold2/ ... fold5/           # Same for folds 2–5
    └── all_folds_summary/          # Cross-validation summary plots

🚀 Quick Start

Installation

git clone https://huggingface.co/Dhruv1000/AgriFM_PASTIS
cd AgriFM_PASTIS
pip install torch torchvision timm einops geopandas matplotlib seaborn scikit-learn

Training (single fold)

python train.py \
    --data_root /path/to/PASTIS \
    --fold 1 \
    --small_model \
    --epochs 100 \
    --batch_size 16 \
    --lr 5e-5 \
    --weight_decay 0.05 \
    --warmup_iters 500 \
    --num_workers 4 \
    --amp \
    --work_dir ./work_dirs/fold1

Training (all 5 folds)

for fold in 1 2 3 4 5; do
    python train.py \
        --data_root /path/to/PASTIS \
        --fold $fold \
        --small_model \
        --epochs 100 \
        --batch_size 16 \
        --lr 5e-5 \
        --weight_decay 0.05 \
        --work_dir ./work_dirs/fold${fold}
done

Inference

import torch
from models.agrifm import build_agrifm_pastis_small

model = build_agrifm_pastis_small(num_classes=20)
# weights_only=False is needed because the checkpoint stores numpy scalars
ckpt  = torch.load('checkpoints/fold1_v3_best_model.pth',
                   map_location='cpu', weights_only=False)
model.load_state_dict(ckpt['model'])
model.eval()

# Input: (B, T, C, H, W) = (batch, 32 frames, 10 bands, 128px, 128px)
x = torch.randn(1, 32, 10, 128, 128)
with torch.no_grad():            # inference only, skip autograd bookkeeping
    logits = model(x)            # (1, 20, 128, 128)
pred = logits.argmax(dim=1)      # (1, 128, 128), class indices

Visualize Results

# Per-fold visualization (8 plots)
python visualize_results.py \
    --work_dir ./work_dirs/fold1 \
    --data_root /path/to/PASTIS \
    --fold 1 --model_size small

# All-folds cross-validation summary
python visualize_all_folds.py

📋 Training Configuration

Parameter Value
Model AgriFM-small (39.6M params)
Optimizer AdamW (β=(0.9, 0.999))
Learning rate 5e-5
Weight decay 0.05
LR schedule Linear warmup (500 iters) → cosine decay to 1e-6
Batch size 16
Epochs 100
Loss Cross-entropy + hard mining (top 50%) + inverse-freq class weights
Augmentation Random h-flip, v-flip, 90/180/270° rotation
AMP Enabled (torch.amp, ROCm 7.0 compatible)
Gradient clipping max_norm = 5.0
Input frames 32 (uniform sampled from 61)
Normalization Per-fold mean/std from NORM_S2_patch.json
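The learning-rate schedule in the table (linear warmup for 500 iterations, then cosine decay to 1e-6) can be written as a pure function of the iteration index. This is a sketch under assumptions: the warmup is taken to start near zero, the schedule is stepped once per iteration, and `total_steps` is a placeholder value, since the number of iterations per epoch is not stated here.

```python
import math


def lr_at(step, total_steps, base_lr=5e-5, min_lr=1e-6, warmup_iters=500):
    """Learning rate at a given iteration (0-indexed)."""
    if step < warmup_iters:
        # Linear ramp that reaches base_lr on the last warmup iteration.
        return base_lr * (step + 1) / warmup_iters
    # Fraction of the post-warmup schedule completed, clamped to [0, 1].
    progress = min((step - warmup_iters) / max(1, total_steps - warmup_iters), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


TOTAL_STEPS = 10_000  # hypothetical; depends on dataset size and batch size
for step in (0, 250, 499, 500, 5_000, 10_000):
    print(step, f"{lr_at(step, TOTAL_STEPS):.2e}")
```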

📦 Dataset: PASTIS

Property Details
Total patches 2,433 geo-referenced tiles
Patch size 128 × 128 pixels at 10 m resolution
Satellite Sentinel-2 (10 spectral bands)
Temporal length 61 observations per patch
Label hierarchy 3 levels; we use channel 0 (20 coarse classes)
Folds 5 official geographic folds (~487–496 patches each)
Fold assignment From metadata.geojson via geopandas
Ignore index Class 19 (Void) excluded from all metrics
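The training configuration feeds 32 frames sampled uniformly from the 61 observations. One way to pick such indices is evenly spaced rounding with `np.linspace`; this is an assumption for illustration, and the repository's dataset class may choose indices differently. The function name is hypothetical.

```python
import numpy as np


def uniform_frame_indices(num_available=61, num_frames=32):
    """Evenly spaced frame indices that always include the first and last frame.

    If num_available < num_frames, some indices repeat.
    """
    return np.round(np.linspace(0, num_available - 1, num_frames)).astype(int)


# Dummy time series with the PASTIS layout (T, C, H, W).
series = np.zeros((61, 10, 128, 128), dtype=np.float32)
indices = uniform_frame_indices()
clip = series[indices]
print(indices[:5], clip.shape)
```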

PASTIS Classes (20 coarse classes)

ID Class Avg IoU (5-fold)
0 Background 50.98%
1 Meadow 56.69%
2 Soft winter wheat 71.47%
3 Corn 75.90%
4 Winter barley 58.64%
5 Winter rapeseed 75.55%
6 Spring barley 35.08%
7 Sunflower 54.87%
8 Grapevine 36.70%
9 Beet 73.90%
10 Winter triticale 22.34%
11 Winter durum wheat 44.56%
12 Fruits/veg/flowers 21.53%
13 Potatoes 36.34%
14 Leguminous fodder 22.41%
15 Soybeans 61.48%
16 Orchard 17.27%
17 Mixed cereal 14.66%
18 Sorghum 13.91%
19 Void (ignore) n/a
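To turn a predicted class-index map into readable output, the IDs above can be kept in a list and indexed directly. A minimal sketch: the names are copied from the table, while `class_pixel_counts` is a hypothetical helper that assumes every value in the map lies in 0..19.

```python
import numpy as np

CLASS_NAMES = [
    "Background", "Meadow", "Soft winter wheat", "Corn", "Winter barley",
    "Winter rapeseed", "Spring barley", "Sunflower", "Grapevine", "Beet",
    "Winter triticale", "Winter durum wheat", "Fruits/veg/flowers", "Potatoes",
    "Leguminous fodder", "Soybeans", "Orchard", "Mixed cereal", "Sorghum", "Void",
]
VOID_ID = 19  # ignored in all metrics


def class_pixel_counts(pred):
    """Map a (H, W) array of class indices to {class name: pixel count}, skipping Void."""
    counts = np.bincount(pred.ravel(), minlength=len(CLASS_NAMES))
    return {
        CLASS_NAMES[i]: int(n)
        for i, n in enumerate(counts)
        if n > 0 and i != VOID_ID
    }


pred = np.zeros((128, 128), dtype=np.int64)
pred[:64] = 3  # top half predicted as Corn
print(class_pixel_counts(pred))
```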

🔧 Key Implementation Decisions

1. No MMSeg/MMEngine dependency

The original GitHub implementation uses MMSegmentation. This reimplementation uses pure PyTorch only, with no registry system and no config framework, so it runs anywhere PyTorch runs.

2. Official fold splits

The initial fold assignment by patch ID prefix caused a catastrophic distribution mismatch (Winter durum wheat ratio: 88× between train and val). This was fixed by loading the official fold assignments from metadata.geojson using geopandas. Val loss at epoch 1 dropped from 34 to 0.67.
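Reading the fold from metadata rather than deriving it from the patch ID can be sketched as follows. The repository uses geopandas; this example uses only the standard `json` module to stay dependency-free. It assumes each GeoJSON feature carries `ID_PATCH` and `Fold` properties, as in the public PASTIS release; the helper name and the inline sample records are made up for illustration.

```python
import json
from collections import defaultdict


def patches_by_fold(geojson_text):
    """Group patch IDs by their fold, as recorded in the metadata."""
    meta = json.loads(geojson_text)
    folds = defaultdict(list)
    for feature in meta["features"]:
        props = feature["properties"]
        folds[int(props["Fold"])].append(int(props["ID_PATCH"]))
    return dict(folds)


# Tiny synthetic stand-in for metadata.geojson (IDs and folds are invented).
example = json.dumps({
    "type": "FeatureCollection",
    "features": [
        {"type": "Feature", "properties": {"ID_PATCH": 10000, "Fold": 1}, "geometry": None},
        {"type": "Feature", "properties": {"ID_PATCH": 10001, "Fold": 2}, "geometry": None},
        {"type": "Feature", "properties": {"ID_PATCH": 10002, "Fold": 1}, "geometry": None},
    ],
})
print(patches_by_fold(example))
```

For the real dataset, the text of metadata.geojson under the PASTIS root would be passed in place of `example`.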

3. Class-weighted loss

Added compute_class_weights() with inverse-frequency weighting. The resulting weights are 0.01 for Background and 7.24 for Winter Durum Wheat, which is essential for learning rare crops.
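A NumPy sketch of how inverse-frequency weights, an ignore index, and top-50% hard mining can be combined in one cross-entropy loss. The exact formula behind the reported weights (0.01 and 7.24) is not given here, so the normalization to mean 1 below is an assumption, and both function names are illustrative rather than the repository's API.

```python
import numpy as np


def inverse_frequency_weights(pixel_counts):
    """Weights proportional to 1 / class frequency, rescaled to have mean 1."""
    counts = np.asarray(pixel_counts, dtype=np.float64)
    freq = counts / counts.sum()
    weights = 1.0 / np.maximum(freq, 1e-12)
    return weights / weights.mean()


def weighted_hard_mined_ce(logits, target, class_weights, ignore_index=19, keep_ratio=0.5):
    """logits: (N, K) per-pixel scores, target: (N,) class indices."""
    class_weights = np.asarray(class_weights, dtype=np.float64)
    valid = target != ignore_index          # drop Void pixels
    logits, target = logits[valid], target[valid]
    if target.size == 0:
        return 0.0
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_pixel = -log_probs[np.arange(target.size), target] * class_weights[target]
    # Hard mining: average only the largest keep_ratio fraction of losses.
    k = max(1, int(per_pixel.size * keep_ratio))
    return float(np.sort(per_pixel)[-k:].mean())


rng = np.random.default_rng(0)
weights = inverse_frequency_weights([9000, 900, 100])  # toy 3-class counts
logits = rng.standard_normal((200, 3))
target = rng.integers(0, 3, size=200)
target[:10] = 19                                        # some ignored pixels
loss_all = weighted_hard_mined_ce(logits, target, weights, keep_ratio=1.0)
loss_hard = weighted_hard_mined_ce(logits, target, weights, keep_ratio=0.5)
print(weights.round(3), round(loss_all, 3), round(loss_hard, 3))
```

Keeping only the hardest half of the pixels can never lower the mean loss, so `loss_hard` is at least `loss_all`.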

4. Model size variants

Variant Params embed_dim depths Use case
base 196.2M 128 [2,2,18,2] Paper-faithful, needs pretrained weights
small ✅ 39.6M 64 [2,2,6,2] Used here, recommended for training from scratch
tiny 9.6M 32 [2,2,4,2] Backup if small overfits

5. AMD ROCm compatibility

  • torch.amp.GradScaler('cuda') and torch.amp.autocast('cuda') (new API)
  • torch.load(..., weights_only=False) for numpy scalar checkpoints
  • MIOPEN_LOG_LEVEL=0 to suppress kernel selection warnings (not errors)

6. Added Cohen's Kappa metric

Not in the original codebase. Kappa is computed from the full confusion matrix. The final 5-fold mean Kappa is 61.17% (substantial agreement).
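All six reported metrics can be derived from one confusion matrix. A minimal sketch using the standard definitions, assuming rows are ground truth and columns are predictions, and that mFscore is the mean of per-class F-scores; `segmentation_metrics` is a hypothetical name, not the function in evaluation/metrics.py.

```python
import numpy as np


def segmentation_metrics(confusion):
    """confusion[i, j] = number of pixels with ground truth i predicted as j."""
    cm = np.asarray(confusion, dtype=np.float64)
    total = cm.sum()
    tp = np.diag(cm)
    gt = cm.sum(axis=1)      # pixels per ground-truth class
    pred = cm.sum(axis=0)    # pixels per predicted class
    precision = tp / np.maximum(pred, 1e-12)
    recall = tp / np.maximum(gt, 1e-12)
    fscore = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    iou = tp / np.maximum(gt + pred - tp, 1e-12)
    oa = tp.sum() / total
    # Cohen's kappa: agreement beyond what the class marginals give by chance.
    expected = (gt * pred).sum() / total**2
    kappa = (oa - expected) / (1.0 - expected)
    return {
        "OA": oa, "Kappa": kappa, "mIoU": iou.mean(), "mFscore": fscore.mean(),
        "mPrecision": precision.mean(), "mRecall": recall.mean(),
    }


metrics = segmentation_metrics([[50, 10], [5, 35]])
print({name: round(value, 4) for name, value in metrics.items()})
```

For PASTIS, the Void row and column (class 19) would be dropped before calling the function, since that class is excluded from all metrics.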


🆚 Comparison: Paper vs GitHub vs This Implementation

Aspect Paper GitHub Ours
Framework PyTorch PyTorch + MMSeg + MMEngine Pure PyTorch
Dataset 25.2M global images Custom HDF5 PASTIS .npy
Input size 224×224 256×256 128×128
Patch embed SwinPatchEmbed3D Identical Identical ✓
Patch size (4,2,2) (4,2,2) Identical ✓
PatchMerging Synchronized mean_frame_down=True Identical ✓
feature_fusion Not explicit 'cat' Identical ✓
Window size Not explicit (8,7,7) Identical ✓
Neck MultiFusionNeck MultiFusionNeck Identical ✓
Head CropFCNHead CropFCNHead Identical ✓
Pre-training LC fractions + L1 Same Not implemented
Class weights Not mentioned Not implemented Added ✓
Kappa metric Not reported Not implemented Added ✓
Augmentation Not detailed flip/rotate flip + rotate ✓
Grad clipping Not specified None max_norm=5.0 ✓
GPU 10× NVIDIA L40 10× NVIDIA L40 1× AMD MI300X
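The flip and rotate augmentation listed in the table has to apply the same transform to every frame and to the label map, otherwise pixels and labels fall out of alignment. A NumPy sketch of that idea; the 0.5 flip probability, the uniform choice among 0/90/180/270°, and the function name are assumptions for illustration.

```python
import numpy as np


def augment(frames, label, rng):
    """frames: (T, C, H, W), label: (H, W). Returns a consistently transformed pair."""
    if rng.random() < 0.5:   # horizontal flip (width axis)
        frames, label = frames[..., ::-1], label[..., ::-1]
    if rng.random() < 0.5:   # vertical flip (height axis)
        frames, label = frames[..., ::-1, :], label[::-1, :]
    k = int(rng.integers(0, 4))  # number of 90 degree rotations
    frames = np.rot90(frames, k, axes=(-2, -1))
    label = np.rot90(label, k, axes=(-2, -1))
    # Copy into contiguous memory so the result can be handed to a tensor library.
    return np.ascontiguousarray(frames), np.ascontiguousarray(label)


rng = np.random.default_rng(0)
label = np.arange(16).reshape(4, 4)
frames = np.broadcast_to(label, (2, 3, 4, 4)).copy()  # every band equals the label
aug_frames, aug_label = augment(frames, label, rng)
print(aug_frames.shape, aug_label.shape)
```

Square patches keep their shape under any of the four rotations, so no resizing is needed afterwards.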

📈 Training Dynamics (Fold 1 example)

Epoch Train Loss Val Loss mFscore mIoU Kappa
1 0.878 0.671 2.79% 1.47% 2.39%
4 0.431 0.472 20.08% 12.34% 10.76%
10 0.320 0.380 ~33% ~22% ~24%
18 0.222 0.323 35.55% 24.17% 24.93%
55 0.083 0.363 53.37% 39.92% 53.16%
92 0.050 0.350 58.20% 44.20% 60.0%
100 0.048 0.360 57.90% 44.10% 59.8%

Total training time per fold: ~32 minutes on AMD MI300X.


🖼️ Visualizations

The repository includes 8 visualization plots per fold plus 6 cross-validation summary plots:

Per-fold (in results/fold{N}/plots/):

  • 0_summary_card.png: Dark-theme dashboard with all metrics
  • 1_training_curves.png: 6-panel training history
  • 2_per_class_iou.png: Sorted horizontal bar chart
  • 3_metrics_radar.png: Spider/radar chart of all 6 metrics
  • 4_confusion_matrix.png: Raw + normalized confusion matrices
  • 5_prediction_maps.png: 6 sample RGB / GT / Pred / Error maps
  • 6_class_iou_scatter.png: Per-class IoU scatter with mIoU line
  • 7_overfitting_analysis.png: Train vs val loss gap over epochs

Cross-validation (results/all_folds_summary/):

  • CV1_metrics_all_folds.png: Grouped bar chart of all folds
  • CV2_per_class_heatmap.png: IoU heatmap, folds × classes
  • CV3_training_curves_overlay.png: All folds on the same axes
  • CV4_summary_table.png: Mean ± std table
  • CV5_mean_per_class_iou.png: Mean IoU with std error bars
  • CV6_metric_boxplots.png: Box plots with individual fold points

📚 Citation

If you use this code or results, please cite the original AgriFM paper:

@article{li2025agrifm,
  title     = {AgriFM: A Multi-source Temporal Remote Sensing Foundation Model
               for Agriculture Mapping},
  author    = {Li, Wenyuan and Liang, Shunlin and Chen, Kun and others},
  journal   = {Remote Sensing of Environment},
  year      = {2025},
  doi       = {10.1016/j.rse.2026.115234}
}

And the PASTIS dataset:

@inproceedings{garnot2021pastis,
  title     = {Panoptic Segmentation of Satellite Image Time Series
               with Convolutional Temporal Attention Networks},
  author    = {Garnot, Vivien Sainte Fare and Landrieu, Loic},
  booktitle = {ICCV},
  year      = {2021}
}

📄 License

Apache 2.0; see LICENSE for details.


Trained on AMD MI300X · ROCm 7.0 · PyTorch 2.x · April 2026
