ducanhdinh/jepa_proof_barlow_twins_replace

BERT encoder pretrained from scratch với Barlow Twins + Lexical Substitution augmentation.

Augmentation strategy

Thay vì span masking, model này dùng lexical substitution để tạo view 2:

Mô tả
View 1 Câu gốc (không thay đổi)
View 2 15–20% token ngẫu nhiên được thay bằng token ngữ nghĩa gần, dự đoán bởi BERT MLM (top-5, loại trừ token gốc)

Quá trình tạo view 2 được thực hiện offline 1 lần trước khi train.

Kiến trúc Barlow Twins

View 1 ──► Encoder (θ) ──► Projector (θ) ──► z1  ──┐
                                                     ├──► Cross-correlation C = Z1ᵀZ2 / N  ──► Loss
View 2 ──► Encoder (θ) ──► Projector (θ) ──► z2  ──┘

Loss = Σ(C_ii - 1)²  +  λ · Σ_{i≠j} C_ij²
         on-diagonal       off-diagonal (redundancy reduction)

Encoder và Projector dùng shared weights (không có target network).

Thông số huấn luyện

Tham số Giá trị
Max sequence length 256
Batch size 256
Epochs 10
Learning rate 0.0001
Projector hidden dim 2048
Projector out dim 8192
Off-diagonal coeff (λ) 0.005
Mask ratio (lexsubst) 0.15–0.2
Top-k candidates 5

Cách dùng — BERT encoder (feature extraction)

from transformers import BertModel, BertTokenizerFast
import torch

tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_barlow_twins_replace")
bert      = BertModel.from_pretrained("ducanhdinh/jepa_proof_barlow_twins_replace/encoder")

encoded = tokenizer(
    ["Hello world!", "Barlow Twins with lexical substitution."],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
with torch.no_grad():
    out     = bert(**encoded)
    cls_emb = out.last_hidden_state[:, 0, :]   # [CLS] token → (B, 768)

Cách dùng — Full model

import torch
from text_barlow_twins_replace import TextBarlowTwinsReplace, BarlowTwinsReplacePretrainConfig

cfg   = BarlowTwinsReplacePretrainConfig()
model = TextBarlowTwinsReplace(cfg)
state = torch.load(
    hf_hub_download("ducanhdinh/jepa_proof_barlow_twins_replace", "pytorch_model.bin"),
    map_location="cpu",
)
model.load_state_dict(state)
model.eval()
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support