`Arko007/saffron-verify-pretrained` is a high-accuracy saffron quality classification model trained on the `Arko007/saffron-verify` dataset. It classifies saffron images into three grades: Mogra, Lacha, and Adulterated.

The best checkpoint was saved at epoch 13, and early stopping triggered at epoch 20. Validation metrics for this checkpoint:

| Metric | Value |
|---|---|
| Macro F1 | 0.9888 |
| Accuracy | 98.96% |
| Val Loss | 0.3562 |

Per-class results on the validation set:

| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| mogra | 0.98 | 0.98 | 0.98 | 56 |
| lacha | 0.98 | 0.98 | 0.98 | 64 |
| adulterated | 1.00 | 1.00 | 1.00 | 72 |
| macro avg | 0.99 | 0.99 | 0.99 | 192 |
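
Macro F1 is the unweighted mean of the per-class F1 scores, so each grade counts equally regardless of its support. The plain-Python sketch below computes per-class precision, recall, and F1 plus the macro average from a confusion matrix. The function name is hypothetical, and the 3×3 counts in the usage lines are made up for illustration; they are not this model's actual confusion matrix.

```python
from typing import Dict, List


def per_class_and_macro_f1(cm: List[List[int]]) -> Dict[str, object]:
    """Per-class precision/recall/F1 and macro F1 from a confusion matrix.

    cm[i][j] = number of samples whose true class is i and predicted class is j.
    """
    n = len(cm)
    precision, recall, f1 = [], [], []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[i][k] for i in range(n)) - tp  # predicted k, but true class differs
        fn = sum(cm[k]) - tp                       # true class k, but predicted differently
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        precision.append(p)
        recall.append(r)
        f1.append(2 * p * r / (p + r) if p + r else 0.0)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "macro_f1": sum(f1) / n,  # unweighted mean: every class counts equally
    }


# Hypothetical counts; rows/columns ordered as mogra, lacha, adulterated
toy_cm = [[8, 2, 0],
          [1, 9, 0],
          [0, 0, 10]]
metrics = per_class_and_macro_f1(toy_cm)
print(round(metrics["macro_f1"], 4))  # prints 0.8997
```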

Training configuration:

| Parameter | Value |
|---|---|
| Base Model | convnext_base (ImageNet-21k pretrained via timm) |
| Image Size | 512 × 512 |
| Effective Batch Size | 96 (16 per GPU × 2 GPUs × 3 gradient-accumulation steps) |
| Optimizer | AdamW (β₁=0.9, β₂=0.999) |
| Learning Rate | 5e-6 (backbone) / 2.5e-5 (head) |
| Scheduler | Warmup (5 epochs) + Cosine Annealing |
| Regularization | Drop rate 0.3, Drop path 0.2, Label smoothing 0.1 |
| Augmentation | Mixup (α=0.4) + CutMix (α=1.0) |
| AMP | float16 |
| Hardware | 2× NVIDIA Tesla T4 (DDP) |
| Best Epoch | 13 / 50 |
| Early Stopping | Patience 7; triggered at Epoch 20 |
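
The table does not say exactly how the warmup and cosine phases are combined. A common reading, assumed here, is a linear warmup over the first 5 epochs followed by cosine annealing toward zero over the remaining 45 of the 50 scheduled epochs, stepped once per epoch, with the same multiplier applied to both parameter groups (backbone 5e-6, head 2.5e-5). The sketch below computes that multiplier in plain Python and also spells out the effective batch size arithmetic. The constant and function names, and starting warmup at 1/5 of the base rate rather than 0, are assumptions. In PyTorch, a function like `lr_multiplier` could be passed to `torch.optim.lr_scheduler.LambdaLR` over an AdamW optimizer built with two parameter groups.

```python
import math

# Effective batch size from the table: per-GPU batch x GPUs x gradient-accumulation steps
EFFECTIVE_BATCH = 16 * 2 * 3  # = 96

BASE_LRS = {"backbone": 5e-6, "head": 2.5e-5}  # per-parameter-group base learning rates
WARMUP_EPOCHS = 5
TOTAL_EPOCHS = 50


def lr_multiplier(epoch: int) -> float:
    """Assumed schedule: scale factor for each group's base LR at a 0-indexed epoch."""
    if epoch < WARMUP_EPOCHS:
        # Linear warmup: 1/5, 2/5, ..., 5/5 of the base LR over the first 5 epochs
        return (epoch + 1) / WARMUP_EPOCHS
    # Cosine annealing from 1.0 toward 0.0 over the remaining 45 epochs
    progress = (epoch - WARMUP_EPOCHS) / (TOTAL_EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def lrs_at(epoch: int) -> dict:
    """Learning rate of every parameter group at the given epoch."""
    return {name: base * lr_multiplier(epoch) for name, base in BASE_LRS.items()}


for epoch in (0, 4, 12, 49):
    print(epoch, lrs_at(epoch))
```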

Validation metrics at selected epochs:

| Epoch | Val Loss | Accuracy | Macro F1 |
|---|---|---|---|
| 1 | 1.0631 | 51.04% | 0.5088 |
| 2 | 0.9541 | 71.88% | 0.7154 |
| 3 | 0.8096 | 81.77% | 0.8118 |
| 5 | 0.5122 | 90.62% | 0.9033 |
| 7 | 0.4153 | 95.31% | 0.9506 |
| 10 | 0.3676 | 97.92% | 0.9777 |
| 13 | 0.3562 | 98.96% | 0.9888 |
| 20 | — | — | — (early stop) |
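
With patience 7, training stops once 7 consecutive epochs pass without improving on the best score seen so far, which is why a best epoch of 13 leads to a stop at epoch 20. The card does not state which metric was monitored; the sketch below assumes validation macro F1 (higher is better). The class name is hypothetical, and the score sequence is made up rather than taken from the real training log.

```python
class EarlyStopping:
    """Stop after `patience` consecutive epochs without improvement (higher score = better)."""

    def __init__(self, patience: int = 7, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = float("-inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch: int, score: float) -> bool:
        """Record one epoch's score; return True when training should stop."""
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.best_epoch = epoch  # the best checkpoint would be saved here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up scores: improve every epoch up to epoch 13, then plateau below the best
scores = [0.5 + 0.03 * e for e in range(1, 14)] + [0.85] * 37
stopper = EarlyStopping(patience=7)
for epoch, score in enumerate(scores, start=1):
    if stopper.step(epoch, score):
        print(f"best epoch {stopper.best_epoch}, stopped at epoch {epoch}")
        break
```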

Training data was augmented offline from 167 real images to 3840 balanced training images (1280 per class) using a heavy Albumentations pipeline that includes random crops, flips, rotations, colour jitter, blur, noise, elastic transforms, perspective distortion, CoarseDropout, and CLAHE. The validation set was augmented from 41 real images to 192 balanced images (64 per class).
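
Reaching exactly 1280 images per class from an uneven number of real images means each source image contributes a different number of augmented variants. The card does not describe its sampling scheme; one simple option, assumed here, is to spread each class's target as evenly as possible across its source images, then run the augmentation pipeline that many times per source. The function names are hypothetical, and the per-class counts in the usage lines are an invented split (the card gives only the 167-image total).

```python
from typing import Dict, List


def variants_per_source(n_sources: int, target: int) -> List[int]:
    """Split `target` augmented images as evenly as possible over `n_sources` real images."""
    if n_sources <= 0:
        raise ValueError("need at least one source image")
    base, extra = divmod(target, n_sources)
    # The first `extra` sources each produce one additional variant
    return [base + 1 if i < extra else base for i in range(n_sources)]


def plan_balanced_set(real_counts: Dict[str, int], per_class: int) -> Dict[str, List[int]]:
    """Per-class list of how many augmented copies to generate from each real image."""
    return {cls: variants_per_source(n, per_class) for cls, n in real_counts.items()}


# Hypothetical number of real images per class
plan = plan_balanced_set({"mogra": 50, "lacha": 60, "adulterated": 57}, per_class=1280)
for cls, copies in plan.items():
    print(cls, sum(copies), min(copies), max(copies))
```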
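
Inference example:

```python
import torch
import torch.nn as nn
import timm
from torchvision import transforms
from PIL import Image


class SaffronVerifyModel(nn.Module):
    """ConvNeXt-Base backbone followed by a small MLP head for the 3 saffron grades."""

    def __init__(self):
        super().__init__()
        # num_classes=0 drops timm's classifier so the backbone returns pooled features;
        # pretrained=False because the fine-tuned weights come from the checkpoint below
        self.backbone = timm.create_model(
            "convnext_base", pretrained=False,
            num_classes=0, drop_rate=0.3, drop_path_rate=0.2,
        )
        feat_dim = self.backbone.num_features
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Dropout(p=0.3),
            nn.Linear(feat_dim, 512),
            nn.GELU(),
            nn.Dropout(p=0.15),
            nn.Linear(512, 3),
        )

    def forward(self, x):
        return self.head(self.backbone(x))


# Label order must match the class indices used in training
CLASSES = ["mogra", "lacha", "adulterated"]

# Load model
model = SaffronVerifyModel()
# PyTorch >= 2.6 defaults to weights_only=True; if loading fails on non-tensor
# entries and you trust the file, pass weights_only=False
ckpt = torch.load("best_model.pth", map_location="cpu")
state = ckpt["model_state"]
# Weights saved from a DDP-wrapped model may carry a "module." prefix; strip it if present
state = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state.items()}
model.load_state_dict(state)
model.eval()  # disables dropout and drop path

# Preprocess: resize the shorter side to 512, center-crop, then ImageNet normalization
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

img = Image.open("saffron.jpg").convert("RGB")
tensor = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 512, 512)

with torch.no_grad():
    logits = model(tensor)
    pred = logits.argmax(dim=1).item()

print(f"Predicted class: {CLASSES[pred]}")
```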
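
The inference snippet reports only the arg-max class. To also show a confidence score, the logits can be turned into probabilities with a softmax. The helper below is a hypothetical, dependency-free version that takes a plain list of three logits (for example `logits[0].tolist()` from the snippet); the logits in the usage line are invented, not real model output.

```python
import math
from typing import List, Tuple

CLASSES = ["mogra", "lacha", "adulterated"]


def top_prediction(logits: List[float], classes: List[str] = CLASSES) -> Tuple[str, float]:
    """Return (label, probability) for the highest-scoring class using a stable softmax."""
    m = max(logits)  # subtract the max logit so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]


label, conf = top_prediction([0.2, 3.1, -1.0])  # example logits only
print(f"{label} ({conf:.1%})")
```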

License: Apache 2.0

Base model: facebook/convnext-base-224-22k-1k