ducanhdinh/jepa_proof_boyl

BERT encoder pretrained from scratch với BYOL (Bootstrap Your Own Latent).

Augmentation strategy

Hai view được tạo bằng span masking độc lập:

Mô tả
View 1 Câu gốc với các span ngẫu nhiên bị mask
View 2 Câu gốc với các span ngẫu nhiên khác bị mask (không overlap)

Kiến trúc BYOL

View 1 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p1 ──┐
                                                                                           ├── loss = cosine(p1, z2) + cosine(p2, z1)
View 2 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p2 ──┘
View 1 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z1  (stop grad)
View 2 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z2  (stop grad)

Target update: ξ ← 0.996·ξ + 0.0040000000000000036·θ  (EMA)

Thông số huấn luyện

Tham số Giá trị
Max sequence length 256
Batch size 256
Epochs 10
Learning rate 0.0001
Projector hidden dim 2048
Projector out dim 256
Predictor hidden dim 1024
EMA decay 0.996
Max span length 5

Cách dùng

from transformers import BertModel, BertTokenizerFast
import torch

tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_boyl")
bert      = BertModel.from_pretrained("ducanhdinh/jepa_proof_boyl/encoder")

encoded = tokenizer(
    ["Hello world!", "BYOL pretraining rocks."],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
with torch.no_grad():
    out     = bert(**encoded)
    cls_emb = out.last_hidden_state[:, 0, :]   # [CLS] token → (B, 768)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support