ViT-Flower: Vision Transformer for Flower Classification

Model Description

ViT-Flower is a Vision Transformer (ViT-Base-Patch16-384) model fine-tuned for flower classification. The model classifies images into 152 flower categories, each labeled with both a Chinese and an English name.

  • Model Architecture: Vision Transformer (ViT-Base-Patch16-384)
  • Task: Image Classification
  • Number of Classes: 152
  • Input Resolution: 384 × 384 pixels
  • Framework: PyTorch + Transformers

Model Architecture

ViT-Base-Patch16-384 Backbone
├── Patch Embedding: 16×16 patches → 768 dim
├── Transformer Encoder: 12 blocks
│   └── Each block: Multi-Head Self-Attention + MLP
└── CLS Token → 768-dim feature

Classification Head
├── Linear: 768 → 1024
├── BatchNorm1d: 1024
├── GELU Activation
├── Dropout: 0.4
└── Linear: 1024 → 152
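The classification head above can be sketched in PyTorch as a minimal reconstruction from the listed layers (the module name `head` and exact layer composition are assumptions; only the layer order and sizes come from this card):

```python
import torch
import torch.nn as nn

# Hypothetical reconstruction of the classification head described above:
# Linear -> BatchNorm1d -> GELU -> Dropout(0.4) -> Linear
head = nn.Sequential(
    nn.Linear(768, 1024),
    nn.BatchNorm1d(1024),
    nn.GELU(),
    nn.Dropout(0.4),
    nn.Linear(1024, 152),
)

# A batch of two 768-dim CLS features mapped to 152 class logits
features = torch.randn(2, 768)
logits = head(features)
print(logits.shape)  # torch.Size([2, 152])
```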

Usage

Using Transformers

import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# Load model and processor
model_name = "your-username/Vit-Flower"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()

# Load and process image (the model expects RGB input)
image = Image.open("flower.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Inference
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
predicted_id = logits.argmax(-1).item()
predicted_label = model.config.id2label[predicted_id]

print(f"Predicted class: {predicted_label}")

Using timm + safetensors

import torch
from safetensors.torch import load_file
import timm
from PIL import Image
from torchvision import transforms

# Load model
model = timm.create_model('vit_base_patch16_384', pretrained=False, num_classes=152)
state_dict = load_file('Vit-Flower.safetensors')
model.load_state_dict(state_dict)
model.eval()

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Inference
image = Image.open("flower.jpg").convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
    predicted_id = output.argmax(-1).item()

Input Preprocessing

Parameter   Value
Image Size  384 × 384
Mean        [0.5, 0.5, 0.5]
Std         [0.5, 0.5, 0.5]
Format      RGB

Model Files

File                      Description
Vit-Flower.safetensors    Model weights in safetensors format
config.json               Model configuration with id2label mapping
preprocessor_config.json  Image preprocessing configuration
id2label_full.json        Complete label mapping (category_id, chinese_name, english_name)

Label Categories (152 Classes)

Index  Category ID  Chinese Name            English Name
0      164          紫叶竹节秋海棠（紫竹梅）  Tradescantia pallida
1      165          龙牙草（仙鹤草）        Agrimonia eupatoria
2      166          络石（风车茉莉）        Trachelospermum jasminoides
3      167          湖北荚蒾                Eriocapitella hupehensis
4      168          桃叶风铃草              Campanula persicifolia
5      169          旱金莲                  Tropaeolum majus
6      170          田野珍珠菜              Lysimachia arvensis
7      171          白花雪果                Symphoricarpos albus
8      172          吊兰                    Chlorophytum comosum
9      173          啤酒花                  Humulus lupulus
...    ...          ...                     ...
48     18           玫瑰                    Rosa rugosa
...    ...          ...                     ...
151    1891         三色堇                  Viola × wittrockiana

See id2label_full.json for complete 152 class mappings.
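The full mapping can be read with the standard `json` module. The sketch below writes a one-entry example file in the layout suggested by the table (the exact key structure of the shipped `id2label_full.json` is an assumption):

```python
import json

# Hypothetical layout of id2label_full.json, inferred from the table above;
# the real file ships with the model repository.
example = {
    "48": {"category_id": 18, "chinese_name": "玫瑰", "english_name": "Rosa rugosa"},
}
with open("id2label_full.json", "w", encoding="utf-8") as f:
    json.dump(example, f, ensure_ascii=False)

# Look up a predicted class index in the mapping
with open("id2label_full.json", encoding="utf-8") as f:
    id2label = json.load(f)
label = id2label["48"]["english_name"]
print(label)  # Rosa rugosa
```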

Training Details

Parameter        Value
Epochs           40
Learning Rate    1.5e-4
Optimizer        AdamW
Weight Decay     8e-5
Batch Size       32
Warmup Epochs    6
Frozen Blocks    10 of 12
Loss Function    FocalCrossEntropyLoss (α=0.25, γ=2)
Label Smoothing  0.1
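The loss named above can be sketched as label-smoothed cross-entropy with the standard focal modulation α(1 − p_t)^γ. This is a hedged reconstruction, not the training code; the class name and exact formulation are assumptions based on the listed hyperparameters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalCrossEntropyLoss(nn.Module):
    """Sketch of a focal cross-entropy with label smoothing (assumed form)."""

    def __init__(self, alpha=0.25, gamma=2.0, label_smoothing=0.1):
        super().__init__()
        self.alpha, self.gamma, self.ls = alpha, gamma, label_smoothing

    def forward(self, logits, targets):
        # Per-sample smoothed cross-entropy
        ce = F.cross_entropy(logits, targets,
                             label_smoothing=self.ls, reduction="none")
        pt = torch.exp(-ce)  # approximate probability of the target class
        # Down-weight easy examples via the focal term (1 - pt)^gamma
        return (self.alpha * (1 - pt) ** self.gamma * ce).mean()

loss_fn = FocalCrossEntropyLoss()
logits = torch.randn(4, 152)
targets = torch.randint(0, 152, (4,))
loss = loss_fn(logits, targets)
```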

Limitations

  • Input images must be RGB format
  • Optimal performance on flower images similar to training distribution
  • Model expects 384×384 input resolution

Citation

@inproceedings{dosovitskiy2021vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and others},
  booktitle={ICLR},
  year={2021}
}

License

Apache 2.0
