ducanhdinh/jepa_proof_boyl_replace
BERT encoder pretrained from scratch với BYOL + Lexical Substitution augmentation.
Augmentation strategy
Thay vì span masking, model này dùng lexical substitution để tạo view 2:
| Mô tả | |
|---|---|
| View 1 | Câu gốc (không thay đổi) |
| View 2 | 15–20% token ngẫu nhiên được thay bằng token ngữ nghĩa gần, dự đoán bởi BERT MLM (top-5, loại trừ token gốc) |
Quá trình tạo view 2 được thực hiện offline 1 lần trước khi train.
Kiến trúc BYOL
View 1 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p1 ──┐
├── loss = cosine(p1, z2) + cosine(p2, z1)
View 2 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p2 ──┘
View 1 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z1 (stop grad)
View 2 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z2 (stop grad)
Target update: ξ ← 0.996·ξ + 0.0040000000000000036·θ (EMA)
Thông số huấn luyện
| Tham số | Giá trị |
|---|---|
| Max sequence length | 256 |
| Batch size | 256 |
| Epochs | 10 |
| Learning rate | 0.0001 |
| Projector hidden dim | 2048 |
| Projector out dim | 256 |
| Predictor hidden dim | 1024 |
| EMA decay | 0.996 |
| Mask ratio (lexsubst) | 0.15–0.2 |
| Top-k candidates | 5 |
Cách dùng — Online BERT encoder (feature extraction)
from transformers import BertModel, BertTokenizerFast
import torch
tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_boyl_replace")
bert = BertModel.from_pretrained("ducanhdinh/jepa_proof_boyl_replace/encoder")
encoded = tokenizer(
["Hello world!", "BYOL with lexical substitution."],
return_tensors="pt",
padding=True,
truncation=True,
)
with torch.no_grad():
out = bert(**encoded)
cls_emb = out.last_hidden_state[:, 0, :] # [CLS] token → (B, 768)
Cách dùng — Full model
import torch
from text_byol_replace import TextBYOLReplace, BYOLReplacePretrainConfig
cfg = BYOLReplacePretrainConfig()
model = TextBYOLReplace(cfg)
state = torch.load(
hf_hub_download("ducanhdinh/jepa_proof_boyl_replace", "pytorch_model.bin"),
map_location="cpu",
)
model.load_state_dict(state)
model.eval()
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support