ducanhdinh/jepa_proof_vicreg
BERT encoder pretrained from scratch với VICReg (Variance-Invariance-Covariance Regularization).
Hai masked text views được encode bởi một BERT encoder dùng chung, sau đó đưa qua expander MLP. VICReg kết hợp 3 loss terms để căn chỉnh các views và ngăn feature collapse mà không cần negative samples:
| Loss term | Hệ số | Mô tả |
|---|---|---|
| Invariance | 25.0 |
MSE giữa z1 và z2 (căn chỉnh hai views) |
| Variance | 25.0 |
Giữ std của mỗi chiều ≥ 1 (chống collapse) |
| Covariance | 1.0 |
Decorrelate các chiều embedding |
Kiến trúc
Text → BERT (mean-pool) → z ∈ R^768 → Expander MLP → z' ∈ R^3072
↑ VICReg loss áp dụng tại đây
Expander gồm 3 lớp Linear-BatchNorm-ReLU (dim = 3072).
Thông số huấn luyện
| Tham số | Giá trị |
|---|---|
| Max sequence length | 256 |
| Batch size | 256 |
| Epochs | 10 |
| Learning rate | 0.0001 |
| Expander dim | 3072 |
| Max span length (masking) | 5 |
| sim_coeff | 25.0 |
| std_coeff | 25.0 |
| cov_coeff | 1.0 |
Cách dùng — BERT encoder (feature extraction)
from transformers import BertModel, BertTokenizerFast
import torch
tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_vicreg")
bert = BertModel.from_pretrained("ducanhdinh/jepa_proof_vicreg/encoder")
encoded = tokenizer(
["Hello world!", "VICReg is great."],
return_tensors="pt",
padding=True,
truncation=True,
)
with torch.no_grad():
out = bert(**encoded)
hidden = out.last_hidden_state # (B, T, 768)
mask = encoded["attention_mask"].unsqueeze(-1).float()
emb = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1) # mean-pool → (B, 768)
Cách dùng — Full model (encoder + expander)
import torch
from transformers import BertTokenizerFast
# Load weights thủ công
from text_vicreg import TextVICReg, VICRegPretrainConfig
cfg = VICRegPretrainConfig()
model = TextVICReg(cfg)
state = torch.load(
hf_hub_download("ducanhdinh/jepa_proof_vicreg", "pytorch_model.bin"),
map_location="cpu",
)
model.load_state_dict(state)
model.eval()
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support