ducanhdinh/jepa_proof_boyl
BERT encoder pretrained from scratch với BYOL (Bootstrap Your Own Latent).
Augmentation strategy
Hai view được tạo bằng span masking độc lập:
| Mô tả | |
|---|---|
| View 1 | Câu gốc với các span ngẫu nhiên bị mask |
| View 2 | Câu gốc với các span ngẫu nhiên khác bị mask (không overlap) |
Kiến trúc BYOL
View 1 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p1 ──┐
├── loss = cosine(p1, z2) + cosine(p2, z1)
View 2 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p2 ──┘
View 1 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z1 (stop grad)
View 2 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z2 (stop grad)
Target update: ξ ← 0.996·ξ + 0.0040000000000000036·θ (EMA)
Thông số huấn luyện
| Tham số | Giá trị |
|---|---|
| Max sequence length | 256 |
| Batch size | 256 |
| Epochs | 10 |
| Learning rate | 0.0001 |
| Projector hidden dim | 2048 |
| Projector out dim | 256 |
| Predictor hidden dim | 1024 |
| EMA decay | 0.996 |
| Max span length | 5 |
Cách dùng
from transformers import BertModel, BertTokenizerFast
import torch
tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_boyl")
bert = BertModel.from_pretrained("ducanhdinh/jepa_proof_boyl/encoder")
encoded = tokenizer(
["Hello world!", "BYOL pretraining rocks."],
return_tensors="pt",
padding=True,
truncation=True,
)
with torch.no_grad():
out = bert(**encoded)
cls_emb = out.last_hidden_state[:, 0, :] # [CLS] token → (B, 768)
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support