ducanhdinh/jepa_proof_boyl_replace

BERT encoder pretrained from scratch với BYOL + Lexical Substitution augmentation.

Augmentation strategy

Thay vì span masking, model này dùng lexical substitution để tạo view 2:

Mô tả
View 1 Câu gốc (không thay đổi)
View 2 15–20% token ngẫu nhiên được thay bằng token ngữ nghĩa gần, dự đoán bởi BERT MLM (top-5, loại trừ token gốc)

Quá trình tạo view 2 được thực hiện offline 1 lần trước khi train.

Kiến trúc BYOL

View 1 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p1 ──┐
                                                                                           ├── loss = cosine(p1, z2) + cosine(p2, z1)
View 2 ──► Online Encoder (θ) ──► Online Projector (θ) ──► Online Predictor (θ) ──► p2 ──┘
View 1 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z1  (stop grad)
View 2 ──► Target Encoder (ξ) ──► Target Projector (ξ) ──► z2  (stop grad)

Target update: ξ ← 0.996·ξ + 0.0040000000000000036·θ  (EMA)

Thông số huấn luyện

Tham số Giá trị
Max sequence length 256
Batch size 256
Epochs 10
Learning rate 0.0001
Projector hidden dim 2048
Projector out dim 256
Predictor hidden dim 1024
EMA decay 0.996
Mask ratio (lexsubst) 0.15–0.2
Top-k candidates 5

Cách dùng — Online BERT encoder (feature extraction)

from transformers import BertModel, BertTokenizerFast
import torch

tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_boyl_replace")
bert      = BertModel.from_pretrained("ducanhdinh/jepa_proof_boyl_replace/encoder")

encoded = tokenizer(
    ["Hello world!", "BYOL with lexical substitution."],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
with torch.no_grad():
    out = bert(**encoded)
    cls_emb = out.last_hidden_state[:, 0, :]   # [CLS] token → (B, 768)

Cách dùng — Full model

import torch
from text_byol_replace import TextBYOLReplace, BYOLReplacePretrainConfig

cfg   = BYOLReplacePretrainConfig()
model = TextBYOLReplace(cfg)
state = torch.load(
    hf_hub_download("ducanhdinh/jepa_proof_boyl_replace", "pytorch_model.bin"),
    map_location="cpu",
)
model.load_state_dict(state)
model.eval()
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support