- AgriFM × PASTIS
- Results: 5-Fold Cross-Validation
- Per-Class IoU: All Folds
- Inference Speed
- Model Architecture
- Repository Structure
- Quick Start
- Training Configuration
- Dataset: PASTIS
- Key Implementation Decisions
- Comparison: Paper vs GitHub vs This Implementation
- Training Dynamics (Fold 1 example)
- Visualizations
- Citation
- License
AgriFM × PASTIS
Reimplementation of AgriFM: A Multi-source Temporal Remote Sensing Foundation Model for Agriculture Mapping (Li et al., 2025), adapted for the PASTIS benchmark dataset and trained from scratch on a single AMD MI300X GPU (ROCm 7.0).
Results: 5-Fold Cross-Validation
Summary
| Metric | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Fold 5 | Mean ± Std |
|---|---|---|---|---|---|---|
| mFscore | 58.37% | 56.69% | 59.18% | 61.00% | 56.49% | 58.35% ± 1.70% |
| mIoU | 44.35% | 42.90% | 45.54% | 46.80% | 42.58% | 44.43% ± 1.62% |
| OA | 67.10% | 66.85% | 69.96% | 71.48% | 66.41% | 68.36% ± 2.13% |
| Kappa | 60.02% | 59.71% | 62.79% | 64.33% | 58.99% | 61.17% ± 2.19% |
| mPrecision | 52.78% | 51.33% | 54.51% | 57.40% | 51.24% | 53.45% ± 2.47% |
| mRecall | 70.72% | 69.00% | 68.69% | 68.93% | 68.72% | 69.21% ± 0.82% |
Official benchmark result: mFscore = 58.35% ± 1.70%, trained from scratch with no pretrained weights, in about 32 minutes per fold on an AMD MI300X.
Per-Class IoU: All Folds
| Class | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Fold 5 | Mean |
|---|---|---|---|---|---|---|
| Corn | 76.76% | 75.31% | 75.75% | 77.09% | 74.60% | 75.90% |
| Winter rapeseed | 74.17% | 76.05% | 77.14% | 77.49% | 72.90% | 75.55% |
| Beet | 73.32% | 72.52% | 78.01% | 71.87% | 73.78% | 73.90% |
| Soft winter wheat | 72.22% | 68.92% | 71.64% | 71.18% | 73.37% | 71.47% |
| Soybeans | 66.81% | 55.03% | 63.35% | 62.73% | 59.49% | 61.48% |
| Winter barley | 56.56% | 59.08% | 57.70% | 62.90% | 56.94% | 58.64% |
| Meadow | 56.56% | 55.96% | 57.47% | 58.18% | 55.29% | 56.69% |
| Sunflower | 53.94% | 48.13% | 64.54% | 57.16% | 50.59% | 54.87% |
| Background | 48.39% | 50.20% | 51.60% | 54.88% | 49.83% | 50.98% |
| Winter durum wheat | 43.30% | 43.68% | 49.19% | 47.26% | 39.38% | 44.56% |
| Potatoes | 37.33% | 37.75% | 34.93% | 42.60% | 29.07% | 36.34% |
| Grapevine | 36.16% | 36.62% | 37.52% | 38.60% | 34.60% | 36.70% |
| Spring barley | 30.51% | 40.00% | 32.87% | 38.06% | 33.97% | 35.08% |
| Leguminous fodder | 23.17% | 20.88% | 21.02% | 28.88% | 18.12% | 22.41% |
| Fruits/veg/flowers | 22.27% | 19.75% | 20.97% | 22.67% | 21.97% | 21.53% |
| Winter triticale | 25.01% | 19.33% | 24.30% | 23.72% | 19.32% | 22.34% |
| Mixed cereal | 17.63% | 12.64% | 13.20% | 16.81% | 13.02% | 14.66% |
| Orchard | 15.25% | 13.11% | 26.05% | 18.81% | 13.14% | 17.27% |
| Sorghum | 13.29% | 10.12% | 8.04% | 18.40% | 19.68% | 13.91% |
Inference Speed
Measured on AMD MI300X (191.7 GB VRAM), AgriFM-small (39.6M params):
| Batch Size | Time (ms) | Throughput | VRAM Used |
|---|---|---|---|
| 1 | 7.8 ms | 128.9 patches/sec | 1.12 GB |
| 4 | 19.6 ms | 203.8 patches/sec | 2.08 GB |
| 8 | 37.1 ms | 215.4 patches/sec | 3.40 GB |
| 16 | 73.4 ms | 218.0 patches/sec | 6.03 GB |
| 32 | 143.0 ms | 223.8 patches/sec | 11.31 GB |
| 64 | 281.0 ms | 227.8 patches/sec | 21.85 GB |
Each patch = 32 temporal frames × 10 spectral bands × 128×128 pixels of Sentinel-2 imagery.
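As a rough sanity check on the table above, throughput is batch size divided by latency, and the per-patch input size follows from the patch shape. The sketch below uses the rounded latencies from the table, so its output differs slightly from the reported throughput at small batch sizes. The float32 storage size is an assumption for illustration; the `throughput` helper is hypothetical and not part of the repository.

```python
# Rounded latencies (ms) copied from the inference speed table.
latency_ms = {1: 7.8, 4: 19.6, 8: 37.1, 16: 73.4, 32: 143.0, 64: 281.0}


def throughput(batch_size, time_ms):
    """Patches per second for one forward pass over `batch_size` patches."""
    return batch_size / (time_ms / 1000.0)


for bs, ms in latency_ms.items():
    print(f"batch={bs:>2}  {throughput(bs, ms):6.1f} patches/sec")

# Values per patch: frames x bands x height x width.
values_per_patch = 32 * 10 * 128 * 128
# Assumption: inputs are stored as float32 (4 bytes per value).
bytes_fp32 = values_per_patch * 4
print(values_per_patch, "values,", bytes_fp32 / 2**20, "MiB per patch")
```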
Model Architecture
Input: (B, 32, 10, 128, 128): 32 frames, 10 S2 bands, 128×128 pixels
↓
SwinPatchEmbed3D patch_size=(4,2,2) → (B, 64, 8, 64, 64)
↓
SwinTransformer3D: 4 stages with synchronized spatiotemporal downsampling
  Stage 1: depth=2, heads=2, dim=64 → (B, 64, 8, 64, 64)
  Stage 2: depth=2, heads=4, dim=128 → (B, 128, 4, 32, 32)
  Stage 3: depth=6, heads=8, dim=256 → (B, 256, 2, 16, 16)
  Stage 4: depth=2, heads=16, dim=512 → (B, 512, 1, 8, 8)
↓
feature_fusion='cat' → (B, 512, 8, 8) (temporal dim collapsed into channels)
↓
MultiFusionNeck: 3-level U-Net decoder with skip connections
  Level 1: 8×8 → 16×16 (skip from Stage 3)
  Level 2: 16×16 → 32×32 (skip from Stage 2)
  Level 3: 32×32 → 64×64 (skip from Stage 1)
  Final: 64×64 → 128×128
↓
CropFCNHead Conv(512→256)+ReLU+Conv(256→20)
↓
Output: (B, 20, 128, 128), per-pixel class scores (logits)
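The backbone shapes in the diagram above can be reproduced with plain arithmetic. The sketch below is only a shape calculator, not the model: `trace_shapes` is a hypothetical helper, and it assumes each merge doubles channels while halving T, H, and W together, which is what the listed stage shapes imply.

```python
def trace_shapes(frames=32, size=128, embed_dim=64, patch=(4, 2, 2), stages=4):
    """Return [(stage_name, (C, T, H, W)), ...] for the backbone stages."""
    # Patch embedding divides T, H, W by the patch size.
    t, h, w = frames // patch[0], size // patch[1], size // patch[2]
    c = embed_dim
    shapes = [("stage1", (c, t, h, w))]
    for i in range(2, stages + 1):
        # Assumed rule: channels double; T, H, and W halve in lockstep.
        c, t, h, w = c * 2, max(t // 2, 1), h // 2, w // 2
        shapes.append((f"stage{i}", (c, t, h, w)))
    return shapes


for name, shape in trace_shapes():
    print(name, shape)
```

With the defaults, the last stage has a temporal length of 1, which is why the fused feature map can drop the time axis.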
Key Innovation: Synchronized Spatiotemporal PatchMerging
Unlike the standard Video Swin Transformer, which only downsamples spatially, AgriFM's PatchMerging with mean_frame_down=True downsamples along both axes at once:
- Spatially: 4 strided samples concatenated → 4C channels
- Temporally: mean pooling over D_step consecutive frames
This keeps spatial and temporal scales synchronized throughout the hierarchy.
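The two steps above can be sketched on a plain array. This is a minimal NumPy illustration, not the repository's PatchMerging layer: the function name, the channels-last layout, and `d_step=2` are assumptions. It also omits any learned layers. Swin-style merging usually follows the concatenation with a normalization and a linear projection from 4C to 2C.

```python
import numpy as np


def sync_patch_merge(x, d_step=2):
    """x: (B, D, H, W, C) -> (B, D // d_step, H // 2, W // 2, 4 * C).

    Assumes D is divisible by d_step and that H and W are even.
    """
    b, d, h, w, c = x.shape
    # Spatial: take the 4 pixels of every 2x2 block and stack them on channels.
    x = np.concatenate(
        [x[:, :, 0::2, 0::2], x[:, :, 1::2, 0::2],
         x[:, :, 0::2, 1::2], x[:, :, 1::2, 1::2]],
        axis=-1,
    )
    # Temporal: average each group of d_step consecutive frames.
    x = x.reshape(b, d // d_step, d_step, h // 2, w // 2, 4 * c).mean(axis=2)
    return x


x = np.random.rand(2, 8, 64, 64, 64)
print(sync_patch_merge(x).shape)  # (2, 4, 32, 32, 256)
```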
Repository Structure
AgriFM_PASTIS/
├── models/
│   ├── video_swin_transformer.py  # VST backbone + synchronized PatchMerging
│   ├── neck.py                    # MultiFusionNeck (U-Net decoder)
│   ├── heads.py                   # CropFCNHead
│   └── agrifm.py                  # Full model + build_agrifm_pastis_small/tiny/base
├── datasets/
│   └── pastis_dataset.py          # Official 5-fold splits via geopandas
├── losses/
│   └── loss.py                    # CropCELoss + hard mining + class weights
├── evaluation/
│   └── metrics.py                 # OA, mIoU, mFscore, Precision, Recall, Kappa
├── train.py                       # Training loop with colored output + logging
├── visualize_results.py           # 8-plot visualization suite per fold
├── visualize_all_folds.py         # Cross-validation summary plots
├── requirements.txt
├── checkpoints/
│   └── fold1_v3_best_model.pth    # Best checkpoint: fold1, epoch 92, mFscore=58.37%
└── results/
    ├── fold1/                     # Logs, metrics, plots for fold 1
    ├── fold2/ ... fold5/          # Same for folds 2–5
    └── all_folds_summary/         # Cross-validation summary plots
Quick Start
Installation
git clone https://huggingface.co/Dhruv1000/AgriFM_PASTIS
cd AgriFM_PASTIS
pip install torch torchvision timm einops geopandas matplotlib seaborn scikit-learn
Training (single fold)
python train.py \
--data_root /path/to/PASTIS \
--fold 1 \
--small_model \
--epochs 100 \
--batch_size 16 \
--lr 5e-5 \
--weight_decay 0.05 \
--warmup_iters 500 \
--num_workers 4 \
--amp \
--work_dir ./work_dirs/fold1
Training (all 5 folds)
for fold in 1 2 3 4 5; do
  python train.py \
    --data_root /path/to/PASTIS \
    --fold "$fold" \
    --small_model \
    --epochs 100 \
    --batch_size 16 \
    --lr 5e-5 \
    --weight_decay 0.05 \
    --work_dir "./work_dirs/fold${fold}"
done
Inference
import torch
from models.agrifm import build_agrifm_pastis_small

model = build_agrifm_pastis_small(num_classes=20)
# weights_only=False is needed because the checkpoint stores numpy scalars;
# only load checkpoints from a source you trust.
ckpt = torch.load('checkpoints/fold1_v3_best_model.pth',
                  map_location='cpu', weights_only=False)
model.load_state_dict(ckpt['model'])
model.eval()

# Input: (B, T, C, H, W) = (batch, 32 frames, 10 bands, 128px, 128px)
x = torch.randn(1, 32, 10, 128, 128)
with torch.no_grad():  # gradients are not needed at inference time
    logits = model(x)        # (1, 20, 128, 128)
pred = logits.argmax(dim=1)  # (1, 128, 128), class indices
Visualize Results
# Per-fold visualization (8 plots)
python visualize_results.py \
--work_dir ./work_dirs/fold1 \
--data_root /path/to/PASTIS \
--fold 1 --model_size small
# All-folds cross-validation summary
python visualize_all_folds.py
Training Configuration
| Parameter | Value |
|---|---|
| Model | AgriFM-small (39.6M params) |
| Optimizer | AdamW (β=(0.9, 0.999)) |
| Learning rate | 5e-5 |
| Weight decay | 0.05 |
| LR schedule | Linear warmup (500 iters) → cosine decay to 1e-6 |
| Batch size | 16 |
| Epochs | 100 |
| Loss | Cross-entropy + hard mining (top 50%) + inverse-freq class weights |
| Augmentation | Random h-flip, v-flip, 90/180/270° rotation |
| AMP | Enabled (torch.amp, ROCm 7.0 compatible) |
| Gradient clipping | max_norm = 5.0 |
| Input frames | 32 (uniformly sampled from 61) |
| Normalization | Per-fold mean/std from NORM_S2_patch.json |
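The learning-rate schedule in the table (linear warmup for 500 iterations, then cosine decay to 1e-6) can be written as a small function. This is a sketch under assumptions: `lr_at` is a hypothetical name, the schedule is stepped per iteration, and the warmup ramps from `base_lr / warmup_iters`. The training script may differ on all three points.

```python
import math


def lr_at(it, total_iters, base_lr=5e-5, min_lr=1e-6, warmup_iters=500):
    """Learning rate at 0-based iteration `it`."""
    if it < warmup_iters:
        # Linear ramp up to base_lr; the starting value is an assumption.
        return base_lr * (it + 1) / warmup_iters
    # Cosine decay from base_lr down to min_lr over the remaining iterations.
    progress = (it - warmup_iters) / max(1, total_iters - warmup_iters)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


total = 10_000  # placeholder; the real value is epochs x iterations per epoch
for it in (0, 249, 499, 500, 5_250, 10_000):
    print(it, f"{lr_at(it, total):.2e}")
```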
Dataset: PASTIS
| Property | Details |
|---|---|
| Total patches | 2,433 geo-referenced tiles |
| Patch size | 128 × 128 pixels at 10m resolution |
| Satellite | Sentinel-2 (10 spectral bands) |
| Temporal length | Up to 61 observations per patch |
| Label hierarchy | 3 levels; we use channel 0 (20 coarse classes) |
| Folds | 5 official geographic folds (~487–496 patches each) |
| Fold assignment | From metadata.geojson via geopandas |
| Ignore index | Class 19 (Void) excluded from all metrics |
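The training configuration reduces each time series to 32 frames by uniform sampling. One simple way to pick evenly spaced frame indices is sketched below. `uniform_frame_indices` is a hypothetical helper and the rounding rule is an assumption; the dataset class in the repository may sample differently.

```python
def uniform_frame_indices(num_obs, num_frames=32):
    """Pick `num_frames` evenly spaced indices in [0, num_obs - 1].

    If num_obs < num_frames, some indices repeat (frames are duplicated).
    """
    if num_frames == 1:
        return [0]
    step = (num_obs - 1) / (num_frames - 1)
    return [round(i * step) for i in range(num_frames)]


idx = uniform_frame_indices(61)
print(len(idx), idx[:5], idx[-1])
```

The first and last observations are always kept, so the sampled sequence spans the full acquisition period.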
PASTIS Classes (20 coarse classes)
| ID | Class | Avg IoU (5-fold) |
|---|---|---|
| 0 | Background | 50.98% |
| 1 | Meadow | 56.69% |
| 2 | Soft winter wheat | 71.47% |
| 3 | Corn | 75.90% |
| 4 | Winter barley | 58.64% |
| 5 | Winter rapeseed | 75.55% |
| 6 | Spring barley | 35.08% |
| 7 | Sunflower | 54.87% |
| 8 | Grapevine | 36.70% |
| 9 | Beet | 73.90% |
| 10 | Winter triticale | 22.34% |
| 11 | Winter durum wheat | 44.56% |
| 12 | Fruits/veg/flowers | 21.53% |
| 13 | Potatoes | 36.34% |
| 14 | Leguminous fodder | 22.41% |
| 15 | Soybeans | 61.48% |
| 16 | Orchard | 17.27% |
| 17 | Mixed cereal | 14.66% |
| 18 | Sorghum | 13.91% |
| 19 | Void (ignore) | n/a |
Key Implementation Decisions
1. No MMSeg/MMEngine dependency
The original GitHub code uses MMSegmentation. This reimplementation uses pure PyTorch, with no registry system or config framework, so it runs anywhere PyTorch runs.
2. Official fold splits
An initial fold assignment by patch ID prefix caused a catastrophic distribution mismatch (the winter durum wheat ratio between train and val was 88×). This was fixed by loading the official fold assignments from metadata.geojson using geopandas. Val loss at epoch 1 dropped from 34 to 0.67.
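Grouping patches by their official fold can be sketched as follows. To stay dependency-free, this example parses the GeoJSON with the standard `json` module instead of geopandas. The property names `Fold` and `ID_PATCH` are assumptions about the metadata file, and `patches_by_fold` is a hypothetical helper.

```python
import json
from collections import defaultdict


def patches_by_fold(geojson):
    """Map fold number -> sorted list of patch IDs from a parsed GeoJSON dict."""
    folds = defaultdict(list)
    for feature in geojson["features"]:
        props = feature["properties"]
        # "Fold" and "ID_PATCH" are assumed property names.
        folds[int(props["Fold"])].append(int(props["ID_PATCH"]))
    return {fold: sorted(ids) for fold, ids in sorted(folds.items())}


# Tiny in-memory example with made-up patch IDs.
example = json.loads("""
{"type": "FeatureCollection", "features": [
  {"type": "Feature", "properties": {"ID_PATCH": 10002, "Fold": 2}, "geometry": null},
  {"type": "Feature", "properties": {"ID_PATCH": 10000, "Fold": 1}, "geometry": null},
  {"type": "Feature", "properties": {"ID_PATCH": 10001, "Fold": 2}, "geometry": null}
]}
""")
print(patches_by_fold(example))
```

For a file on disk, pass the result of `json.load` on the opened metadata.geojson to the same function.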
3. Class-weighted loss
Added compute_class_weights() with inverse-frequency weighting. Background weight = 0.01 and Winter durum wheat weight = 7.24; this weighting was essential for learning rare crops.
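Inverse-frequency weighting can be sketched in a few lines. This is not the repository's compute_class_weights(): the function name, the normalization (weights average to 1 over valid classes), and the handling of the ignore class are all assumptions, so the values will not match the weights quoted above.

```python
def inverse_frequency_weights(pixel_counts, ignore_index=19):
    """Per-class weights proportional to 1 / frequency.

    pixel_counts[c] is the number of training pixels labeled as class c.
    The ignore class and empty classes get weight 0.
    """
    valid = [(c, n) for c, n in enumerate(pixel_counts)
             if c != ignore_index and n > 0]
    total = sum(n for _, n in valid)
    raw = {c: total / n for c, n in valid}
    # Assumed convention: rescale so the valid weights average to 1.
    scale = len(raw) / sum(raw.values())
    weights = [0.0] * len(pixel_counts)
    for c, w in raw.items():
        weights[c] = w * scale
    return weights


# Made-up counts: one dominant class and two rare ones.
print(inverse_frequency_weights([1000, 50, 10]))
```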
4. Model size variants
| Variant | Params | embed_dim | depths | Use case |
|---|---|---|---|---|
| base | 196.2M | 128 | [2,2,18,2] | Paper-faithful, needs pretrained weights |
| small ✅ | 39.6M | 64 | [2,2,6,2] | Used here, recommended for training from scratch |
| tiny | 9.6M | 32 | [2,2,4,2] | Backup if small overfits |
5. AMD ROCm compatibility
- torch.amp.GradScaler('cuda') and torch.amp.autocast('cuda') (new API)
- torch.load(..., weights_only=False) for numpy scalar checkpoints
- MIOPEN_LOG_LEVEL=0 to suppress kernel selection warnings (not errors)
6. Added Cohen's Kappa metric
Not in the original codebase. Computed from the full confusion matrix. The final 5-fold mean Kappa is 61.17% (substantial agreement).
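Cohen's Kappa compares observed agreement p_o with the agreement p_e expected by chance: kappa = (p_o - p_e) / (1 - p_e). A minimal sketch of that formula on a confusion matrix is shown below. `cohens_kappa` is a hypothetical helper, and it assumes the ignore class has already been removed from the matrix.

```python
def cohens_kappa(cm):
    """cm[i][j] = number of pixels with true class i predicted as class j."""
    k = len(cm)
    n = sum(sum(row) for row in cm)
    # Observed agreement: fraction of pixels on the diagonal.
    observed = sum(cm[i][i] for i in range(k)) / n
    # Chance agreement from the row (true) and column (predicted) totals.
    row_tot = [sum(cm[i]) for i in range(k)]
    col_tot = [sum(cm[i][j] for i in range(k)) for j in range(k)]
    expected = sum(r * c for r, c in zip(row_tot, col_tot)) / (n * n)
    return (observed - expected) / (1 - expected)


# Made-up 2-class example: 85% observed agreement, 50% chance agreement.
print(cohens_kappa([[45, 5], [10, 40]]))
```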
Comparison: Paper vs GitHub vs This Implementation
| Aspect | Paper | GitHub | Ours |
|---|---|---|---|
| Framework | PyTorch | PyTorch + MMSeg + MMEngine | Pure PyTorch |
| Dataset | 25.2M global images | Custom HDF5 | PASTIS .npy |
| Input size | 224×224 | 256×256 | 128×128 |
| Patch embed | SwinPatchEmbed3D | Identical | Identical ✅ |
| Patch size | (4,2,2) | (4,2,2) | Identical ✅ |
| PatchMerging | Synchronized | mean_frame_down=True | Identical ✅ |
| feature_fusion | Not explicit | 'cat' | Identical ✅ |
| Window size | Not explicit | (8,7,7) | Identical ✅ |
| Neck | MultiFusionNeck | MultiFusionNeck | Identical ✅ |
| Head | CropFCNHead | CropFCNHead | Identical ✅ |
| Pre-training | LC fractions + L1 | Same | Not implemented |
| Class weights | Not mentioned | Not implemented | Added ✅ |
| Kappa metric | Not reported | Not implemented | Added ✅ |
| Augmentation | Not detailed | flip/rotate | flip + rotate ✅ |
| Grad clipping | Not specified | None | max_norm=5.0 ✅ |
| GPU | 10× NVIDIA L40 | 10× NVIDIA L40 | 1× AMD MI300X |
Training Dynamics (Fold 1 example)
| Epoch | Train Loss | Val Loss | mFscore | mIoU | Kappa |
|---|---|---|---|---|---|
| 1 | 0.878 | 0.671 | 2.79% | 1.47% | 2.39% |
| 4 | 0.431 | 0.472 | 20.08% | 12.34% | 10.76% |
| 10 | 0.320 | 0.380 | ~33% | ~22% | ~24% |
| 18 | 0.222 | 0.323 | 35.55% | 24.17% | 24.93% |
| 55 | 0.083 | 0.363 | 53.37% | 39.92% | 53.16% |
| 92 | 0.050 | 0.350 | 58.20% | 44.20% | 60.0% |
| 100 | 0.048 | 0.360 | 57.90% | 44.10% | 59.8% |
Total training time per fold: ~32 minutes on AMD MI300X.
Visualizations
The repository includes 8 visualization plots per fold plus 6 cross-validation summary plots:
Per-fold (in results/fold{N}/plots/):
- 0_summary_card.png: Dark-theme dashboard with all metrics
- 1_training_curves.png: 6-panel training history
- 2_per_class_iou.png: Sorted horizontal bar chart
- 3_metrics_radar.png: Spider/radar chart of all 6 metrics
- 4_confusion_matrix.png: Raw + normalized confusion matrices
- 5_prediction_maps.png: 6 sample RGB / GT / Pred / Error maps
- 6_class_iou_scatter.png: Per-class IoU scatter with mIoU line
- 7_overfitting_analysis.png: Train vs val loss gap over epochs
Cross-validation (results/all_folds_summary/):
- CV1_metrics_all_folds.png: Grouped bar chart of all folds
- CV2_per_class_heatmap.png: IoU heatmap, folds × classes
- CV3_training_curves_overlay.png: All folds on the same axes
- CV4_summary_table.png: Mean ± std table
- CV5_mean_per_class_iou.png: Mean IoU with std error bars
- CV6_metric_boxplots.png: Box plots with individual fold points
Citation
If you use this code or results, please cite the original AgriFM paper:
@article{li2025agrifm,
title = {AgriFM: A Multi-source Temporal Remote Sensing Foundation Model
for Agriculture Mapping},
author = {Li, Wenyuan and Liang, Shunlin and Chen, Kun and others},
journal = {Remote Sensing of Environment},
year = {2025},
doi = {10.1016/j.rse.2026.115234}
}
And the PASTIS dataset:
@inproceedings{garnot2021pastis,
title = {Panoptic Segmentation of Satellite Image Time Series
with Convolutional Temporal Attention Networks},
author = {Garnot, Vivien Sainte Fare and Landrieu, Loic},
booktitle = {ICCV},
year = {2021}
}
License
Apache 2.0. See LICENSE for details.
Trained on AMD MI300X · ROCm 7.0 · PyTorch 2.x · April 2026
Evaluation results
- Mean IoU (5-fold CV) on PASTIS (self-reported): 44.43
- F1 Score (5-fold CV) on PASTIS (self-reported): 58.35
- Overall Accuracy (5-fold CV) on PASTIS (self-reported): 68.36