# ViT-Flower: Vision Transformer for Flower Classification
## Model Description
ViT-Flower is a Vision Transformer (ViT-Base-Patch16-384) model fine-tuned for flower classification. The model classifies images into 152 different flower categories with both Chinese and English names.
- Model Architecture: Vision Transformer (ViT-Base-Patch16-384)
- Task: Image Classification
- Number of Classes: 152
- Input Resolution: 384 × 384 pixels
- Framework: PyTorch + Transformers
## Model Architecture

```
ViT-Base-Patch16-384 Backbone
├── Patch Embedding: 16×16 patches → 768 dim
├── Transformer Encoder: 12 blocks
│   └── Each block: Multi-Head Self-Attention + MLP
└── CLS Token → 768-dim feature

Classification Head
├── Linear: 768 → 1024
├── BatchNorm1d: 1024
├── GELU Activation
├── Dropout: 0.4
└── Linear: 1024 → 152
```
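The classification head above can be sketched in PyTorch as follows. This is a minimal illustration of the layer stack described in the diagram; the module names are assumptions and need not match the checkpoint's actual parameter names.

```python
import torch
import torch.nn as nn

# Sketch of the classification head:
# 768 -> 1024 -> BatchNorm1d -> GELU -> Dropout(0.4) -> 1024 -> 152.
# Layer order follows the diagram above; names are illustrative only.
head = nn.Sequential(
    nn.Linear(768, 1024),
    nn.BatchNorm1d(1024),
    nn.GELU(),
    nn.Dropout(0.4),
    nn.Linear(1024, 152),
)

# A batch of two 768-dim CLS features yields two 152-way logit vectors.
features = torch.randn(2, 768)
logits = head(features)
print(logits.shape)  # torch.Size([2, 152])
```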
## Usage

### Using Transformers

```python
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

model_name = "your-username/Vit-Flower"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()

image = Image.open("flower.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

# id2label keys are integers once from_pretrained() has parsed config.json
predicted_id = logits.argmax(-1).item()
predicted_label = model.config.id2label[predicted_id]
print(f"Predicted class: {predicted_label}")
```
### Using timm + safetensors

```python
import torch
from safetensors.torch import load_file
import timm
from PIL import Image
from torchvision import transforms

# Rebuild the backbone with a 152-way head and load the fine-tuned weights
model = timm.create_model('vit_base_patch16_384', pretrained=False, num_classes=152)
state_dict = load_file('Vit-Flower.safetensors')
model.load_state_dict(state_dict)
model.eval()

transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

image = Image.open("flower.jpg").convert('RGB')
input_tensor = transform(image).unsqueeze(0)  # add batch dimension

with torch.no_grad():
    output = model(input_tensor)
predicted_id = output.argmax(-1).item()
```
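To turn the raw logits from either loading path into ranked predictions, a small softmax top-k helper can be used. This is a generic sketch, not part of the released code; `topk_predictions` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def topk_predictions(logits: torch.Tensor, k: int = 5):
    """Return (probability, class_id) pairs for the k highest-scoring classes."""
    probs = F.softmax(logits, dim=-1)
    top = torch.topk(probs, k, dim=-1)
    return list(zip(top.values.squeeze(0).tolist(), top.indices.squeeze(0).tolist()))

# Example with a fake 1x152 logit vector where class 7 dominates.
logits = torch.zeros(1, 152)
logits[0, 7] = 10.0
print(topk_predictions(logits, k=3)[0][1])  # 7
```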
## Input Preprocessing

| Parameter | Value |
|-----------|-------|
| Image Size | 384 × 384 |
| Mean | [0.5, 0.5, 0.5] |
| Std | [0.5, 0.5, 0.5] |
| Format | RGB |
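With mean and std both 0.5 per channel, normalization maps pixel values from [0, 1] (after `ToTensor`-style scaling) to [-1, 1]. A quick arithmetic check:

```python
# Normalize: out = (in - mean) / std, with mean = std = 0.5 per channel.
def normalize(x: float) -> float:
    return (x - 0.5) / 0.5

print(normalize(0.0))  # -1.0
print(normalize(1.0))  # 1.0
print(normalize(0.5))  # 0.0
```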
## Model Files

| File | Description |
|------|-------------|
| `Vit-Flower.safetensors` | Model weights in safetensors format |
| `config.json` | Model configuration with id2label mapping |
| `preprocessor_config.json` | Image preprocessing configuration |
| `id2label_full.json` | Complete label mapping (category_id, chinese_name, english_name) |
## Label Categories (152 Classes)

| Index | Category ID | Chinese Name | English Name |
|-------|-------------|--------------|--------------|
| 0 | 164 | 紫叶竹节秋海棠（紫竹梅） | Tradescantia pallida |
| 1 | 165 | 龙牙草（仙鹤草） | Agrimonia eupatoria |
| 2 | 166 | 络石（风车茉莉） | Trachelospermum jasminoides |
| 3 | 167 | 湖北秋牡丹 | Eriocapitella hupehensis |
| 4 | 168 | 桃叶风铃草 | Campanula persicifolia |
| 5 | 169 | 旱金莲 | Tropaeolum majus |
| 6 | 170 | 田野珍珠菜 | Lysimachia arvensis |
| 7 | 171 | 白花雪果 | Symphoricarpos albus |
| 8 | 172 | 吊兰 | Chlorophytum comosum |
| 9 | 173 | 啤酒花 | Humulus lupulus |
| ... | ... | ... | ... |
| 48 | 18 | 玫瑰 | Rosa rugosa |
| ... | ... | ... | ... |
| 151 | 1891 | 三色堇 | Viola × wittrockiana |

See `id2label_full.json` for the complete 152-class mapping.
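A prediction index can be resolved to its bilingual label via `id2label_full.json`. The schema below is an assumption inferred from the file description above (category_id, chinese_name, english_name); inspect the actual file, as its real structure may differ.

```python
import json

# Hypothetical id2label_full.json structure -- an assumption, not the
# confirmed schema of the released file.
example = json.loads("""
{
  "0": {"category_id": 164, "chinese_name": "\\u7d2b\\u7af9\\u6885", "english_name": "Tradescantia pallida"},
  "1": {"category_id": 165, "chinese_name": "\\u9f99\\u7259\\u8349", "english_name": "Agrimonia eupatoria"}
}
""")

def english_name(predicted_id: int, mapping: dict) -> str:
    """Look up the English name for a model output index (assumed schema)."""
    return mapping[str(predicted_id)]["english_name"]

print(english_name(0, example))  # Tradescantia pallida
```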
## Training Details

| Parameter | Value |
|-----------|-------|
| Epochs | 40 |
| Learning Rate | 1.5e-4 |
| Optimizer | AdamW |
| Weight Decay | 8e-5 |
| Batch Size | 32 |
| Warmup Epochs | 6 |
| Frozen Blocks | 10 of 12 |
| Loss Function | FocalCrossEntropyLoss (α=0.25, γ=2) |
| Label Smoothing | 0.1 |
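The focal loss named above can be sketched as follows. This is one common formulation with α=0.25, γ=2 and label smoothing folded in via PyTorch's built-in `cross_entropy`; the actual `FocalCrossEntropyLoss` used in training may be implemented differently.

```python
import torch
import torch.nn.functional as F

def focal_cross_entropy(logits, targets, alpha=0.25, gamma=2.0, smoothing=0.1):
    """Focal loss on top of label-smoothed cross-entropy (a common formulation)."""
    ce = F.cross_entropy(logits, targets, label_smoothing=smoothing, reduction="none")
    pt = torch.exp(-ce)                      # approximate true-class probability
    loss = alpha * (1 - pt) ** gamma * ce    # down-weight easy examples
    return loss.mean()

logits = torch.randn(4, 152)
targets = torch.randint(0, 152, (4,))
loss = focal_cross_entropy(logits, targets)
print(loss.item() >= 0)  # True
```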
## Limitations

- Input images must be in RGB format
- Performance is best on flower images similar to the training distribution
- The model expects 384 × 384 input resolution
## Citation

```bibtex
@article{dosovitskiy2021vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and others},
  journal={ICLR},
  year={2021}
}
```
## License

Apache 2.0