Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import entrypoint_setup | |
| import random | |
| import tempfile | |
| import os | |
| import torch | |
| from esm2.modeling_fastesm import FastEsmForMaskedLM | |
| from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM | |
| from e1_fastplms.modeling_e1 import E1ForMaskedLM | |
| from dplm_fastplms.modeling_dplm import DPLMForMaskedLM | |
| from dplm2_fastplms.modeling_dplm2 import ( | |
| DPLM2ForMaskedLM, | |
| _has_packed_multimodal_layout, | |
| _normalize_dplm2_input_ids, | |
| ) | |
| from embedding_mixin import parse_fasta | |
# One-letter codes of the 20 canonical amino acids; alphabet for random test sequences.
CANONICAL_AAS: str = "ACDEFGHIKLMNPQRSTVWY"

# Seed handed to random.seed() by the script entry point.
SEED: int = 42

# Default for the --batch_size command-line option.
DEFAULT_BATCH_SIZE: int = 4

# Fixed pad length used to keep max_seqlen identical across runs.
MAX_EMBED_LEN: int = 128
# Models exercised by this script, one 4-tuple per model:
#   (display_name, model_class, hf_path, use_model_tokenizer)
# use_model_tokenizer=False (E1 only) makes test_model pass tokenizer=None and
# skip the batch-vs-single match test.
MODEL_CONFIGS = [
    ("ESM2", FastEsmForMaskedLM, "Synthyra/ESM2-8M", True),
    ("ESM++", ESMplusplusForMaskedLM, "Synthyra/ESMplusplus_small", True),
    ("E1", E1ForMaskedLM, "Synthyra/Profluent-E1-150M", False),
    ("DPLM", DPLMForMaskedLM, "Synthyra/DPLM-150M", True),
    ("DPLM2", DPLM2ForMaskedLM, "Synthyra/DPLM2-150M", True),
]
def test_parse_fasta() -> None:
    """Check parse_fasta on single-line and multi-line FASTA records.

    Writes a three-record FASTA file to a temporary path, parses it with
    parse_fasta, and asserts that the result equals the expected list of
    sequences (the two-line record is expected back as one joined string).

    The temporary file is removed in a ``finally`` clause so it does not leak
    when writing or parsing raises.

    Raises:
        AssertionError: if the parsed sequences differ from the expected ones.
    """
    fasta_content = (
        ">seq1 a simple protein\n"
        "MKTLLLTLVVVTIVCLDLGYT\n"
        ">seq2 multi-line sequence\n"
        "ACDEFGHIKL\n"
        "MNPQRSTVWY\n"
        ">seq3 another entry\n"
        "MALWMRLLPLLALL\n"
    )
    expected = [
        "MKTLLLTLVVVTIVCLDLGYT",
        "ACDEFGHIKLMNPQRSTVWY",
        "MALWMRLLPLLALL",
    ]
    # delete=False: the file is closed before parse_fasta reopens it by path,
    # so cleanup is done explicitly below.
    tmp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".fasta", delete=False)
    tmp_path = tmp_file.name
    try:
        with tmp_file:
            tmp_file.write(fasta_content)
        parsed = parse_fasta(tmp_path)
    finally:
        # Always remove the temp file, even if the write or the parse failed.
        os.unlink(tmp_path)
    assert parsed == expected, f"parse_fasta mismatch:\n got: {parsed}\n expected: {expected}"
    print("test_parse_fasta: OK")
class FixedLengthTokenizer:
    """Tokenizer wrapper that always pads/truncates to a fixed token count.

    Every call forwards the sequences to the wrapped tokenizer with
    ``padding="max_length"``, ``truncation=True`` and ``max_length`` set to the
    configured value, returning PyTorch tensors. A batch of one and a batch of
    N therefore get tensors with the same sequence dimension. Per the original
    author's note, this is meant to keep max_seqlen_in_batch identical and to
    remove floating-point variability from different softmax vector lengths /
    flash-attention tile sizes.

    Args:
        tokenizer: the callable tokenizer to wrap.
        max_length: number of tokens every sequence is padded/truncated to;
            defaults to MAX_EMBED_LEN.
    """

    def __init__(self, tokenizer, max_length: int = MAX_EMBED_LEN):
        self._tok = tokenizer
        self.max_length = max_length

    def __call__(self, sequences, **kwargs):
        """Tokenize *sequences* with the fixed settings.

        Extra keyword arguments are accepted but not forwarded, so callers
        cannot override the padding/truncation settings.
        """
        fixed_settings = {
            "return_tensors": "pt",
            "padding": "max_length",
            "max_length": self.max_length,
            "truncation": True,
        }
        return self._tok(sequences, **fixed_settings)
def random_sequences(n: int, min_len: int = 8, max_len: int = 64) -> list[str]:
    """Return *n* random variable-length sequences; used for the NaN test.

    Each sequence is "M" followed by residues drawn uniformly, with
    replacement, from CANONICAL_AAS. The total length, *including* the leading
    "M", is drawn uniformly from ``[min_len, max_len]``. This matches
    random_sequences_fixed_len, which also counts the "M" toward the requested
    length; previously the "M" was added on top, so sequences came out one
    residue longer than the requested bounds.

    Args:
        n: number of sequences to generate.
        min_len: smallest total sequence length (inclusive).
        max_len: largest total sequence length (inclusive).

    Returns:
        A list of *n* sequences. Because of the fixed "M" prefix a sequence is
        never shorter than one residue, even when ``min_len`` is below 1.

    Raises:
        ValueError: raised by random.randint when ``min_len > max_len``.
    """
    return [
        # One position is taken by "M"; clamp so the residue count is never negative.
        "M" + "".join(
            random.choices(CANONICAL_AAS, k=max(random.randint(min_len, max_len) - 1, 0))
        )
        for _ in range(n)
    ]
def random_sequences_fixed_len(n: int, length: int = 64) -> list[str]:
    """Return *n* random sequences built with the same length setting.

    Used for the match test with E1 (sequence mode). Each sequence is "M"
    followed by ``length - 1`` residues drawn with replacement from
    CANONICAL_AAS, so the "M" counts toward ``length``.
    """
    tail_len = length - 1

    def _one_sequence() -> str:
        residues = random.choices(CANONICAL_AAS, k=tail_len)
        return "M" + "".join(residues)

    return [_one_sequence() for _ in range(n)]
def assert_no_nan(embeddings: dict[str, torch.Tensor], label: str) -> None:
    """Assert that no embedding tensor in *embeddings* contains a NaN.

    Args:
        embeddings: mapping from sequence string to its embedding tensor.
        label: prefix included in the failure message to identify the check.

    Raises:
        AssertionError: on the first tensor that contains a NaN; the message
            shows the first 20 characters of the offending sequence.
    """
    for sequence, embedding in embeddings.items():
        contains_nan = bool(torch.isnan(embedding).any())
        assert not contains_nan, (
            f"[{label}] NaN found in embedding for sequence '{sequence[:20]}...'"
        )
def assert_embeddings_match(
    a: dict[str, torch.Tensor],
    b: dict[str, torch.Tensor],
    label: str,
    atol: float = 5e-3,
) -> None:
    """Assert that two embedding dicts agree element-wise within *atol*.

    Both dicts must have the same keys. For every key the two tensors are cast
    to float32, must have the same shape, and their maximum absolute
    difference must not exceed *atol*. Callers in this file pass embeddings
    produced with full_embeddings=True, which the original author notes
    already strips padding rows, so the comparison is over real-token rows.

    Zero-element tensors of equal shape are treated as matching; calling
    ``.max()`` on them would otherwise raise a RuntimeError.

    Args:
        a: embeddings from the first run, keyed by sequence.
        b: embeddings from the second run, keyed by sequence.
        label: prefix included in failure messages.
        atol: maximum allowed absolute difference per element.

    Raises:
        AssertionError: if the key sets differ, a shape differs, or the max
            absolute difference exceeds *atol*. A NaN difference also fails,
            because ``nan <= atol`` is false.
    """
    assert set(a) == set(b), f"[{label}] Key sets differ between batch and single runs"
    for seq, emb_a in a.items():
        ea, eb = emb_a.float(), b[seq].float()
        assert ea.shape == eb.shape, (
            f"[{label}] Shape mismatch for '{seq[:20]}': {ea.shape} vs {eb.shape}"
        )
        if ea.numel() == 0:
            # Nothing to compare; a reduction over an empty tensor would raise.
            continue
        max_diff = (ea - eb).abs().max().item()
        assert max_diff <= atol, (
            f"[{label}] Max abs diff {max_diff:.5f} > {atol} for '{seq[:20]}'"
        )
def test_dplm2_multimodal_layout_guard() -> None:
    """Check _has_packed_multimodal_layout on three hand-written type-id batches.

    With aa_type=1, struct_type=0 and pad_type=2, the guard must be falsy for
    the "plain sequence" batch, truthy for the "packed multimodal" batch, and
    falsy for the "mismatched multimodal" row.

    Raises:
        AssertionError: if any of the three expectations does not hold.
    """
    type_codes = {"aa_type": 1, "struct_type": 0, "pad_type": 2}

    plain_sequence_type_ids = torch.tensor([
        [1, 1, 1, 1, 1, 1, 0, 2],
        [1, 1, 1, 1, 1, 0, 2, 2],
    ])
    packed_multimodal_type_ids = torch.tensor([
        [1, 1, 1, 2, 0, 0, 0, 2],
        [1, 1, 2, 2, 0, 0, 2, 2],
    ])
    mismatched_multimodal_type_ids = torch.tensor([
        [1, 1, 1, 2, 0, 0, 2, 2],
    ])

    # (type-id tensor, whether the guard should report a packed layout)
    cases = (
        (plain_sequence_type_ids, False),
        (packed_multimodal_type_ids, True),
        (mismatched_multimodal_type_ids, False),
    )
    for type_ids, expect_packed in cases:
        detected = _has_packed_multimodal_layout(type_ids, **type_codes)
        assert bool(detected) == expect_packed
    print("test_dplm2_multimodal_layout_guard: OK")
def test_dplm2_special_token_normalization() -> None:
    """Check _normalize_dplm2_input_ids against a hand-written expectation.

    With vocab_size=8229, the ids 8231, 8229 and 8232 are expected to come
    back as 0, 2 and 32; the remaining ids, including -100, are expected to
    pass through unchanged.

    Raises:
        AssertionError: if the normalized ids differ from the expected tensor.
    """
    expected = torch.tensor([[0, 5, 23, 13, 2, 1, 32, -100]])
    raw_ids = torch.tensor([[8231, 5, 23, 13, 8229, 1, 8232, -100]])
    normalized_input_ids = _normalize_dplm2_input_ids(raw_ids, vocab_size=8229)
    ids_match = torch.equal(normalized_input_ids, expected)
    assert ids_match, (
        f"DPLM2 special-token normalization mismatch:\n"
        f" got: {normalized_input_ids.tolist()}\n"
        f" expected: {expected.tolist()}"
    )
    print("test_dplm2_special_token_normalization: OK")
def test_model(name: str, model_cls, model_path: str, use_model_tokenizer: bool, batch_size: int) -> None:
    """Load one model and run the NaN check plus, where applicable, the match check.

    Args:
        name: display name used in log lines and assertion labels.
        model_cls: class exposing ``from_pretrained``; the loaded model must
            provide ``embed_dataset`` (and ``tokenizer`` when
            use_model_tokenizer is true).
        model_path: identifier passed to ``from_pretrained``.
        use_model_tokenizer: when true, wrap ``model.tokenizer`` in
            FixedLengthTokenizer and also run the batch-vs-single match test;
            when false, pass ``tokenizer=None`` and run only the NaN check.
        batch_size: batch size for the batched embedding runs.

    Raises:
        AssertionError: if any embedding contains NaN, or if batched and
            single-sequence embeddings differ beyond the tolerance.
    """
    print(f"\n--- {name} ({model_path}) ---")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    loaded = model_cls.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    model = loaded.eval()

    if not use_model_tokenizer:
        # E1 (sequence mode): no tokenizer wrapper is available, so padding is
        # controlled by generating sequences with one fixed length setting.
        tokenizer = None
        sequences = random_sequences_fixed_len(n=8)
    else:
        # Variable-length sequences, every batch padded to MAX_EMBED_LEN by the
        # wrapper so batch=1 and batch=N see the same tensor shape.
        tokenizer = FixedLengthTokenizer(model.tokenizer)
        sequences = random_sequences(n=8)

    # Arguments shared by every embed_dataset call below. Per the original
    # notes, full_embeddings=True returns only real (non-pad) token rows.
    common_kwargs = {"tokenizer": tokenizer, "full_embeddings": True, "save": False}

    # --- NaN test ------------------------------------------------------------
    # Runs in bfloat16, the dtype the model was loaded in; batched inference
    # with padding present must not yield NaN in the returned rows.
    nan_embs = model.embed_dataset(
        sequences=sequences,
        batch_size=batch_size,
        embed_dtype=torch.bfloat16,
        **common_kwargs,
    )
    nan_label = f"{name} NaN check batch_size={batch_size}"
    assert_no_nan(nan_embs, nan_label)
    shapes = [tuple(e.shape) for e in list(nan_embs.values())[:3]]
    print(f" NaN check batch_size={batch_size}: OK sample shapes={shapes}")

    # --- Match test (tokenizer / SDPA models only) ---------------------------
    # Rationale recorded by the original author: the NaN fix only touches SDPA
    # backends; E1 uses flash varlen, which unpads on its own and is not
    # bit-deterministic across batch sizes, so a tight match test for E1 is
    # not meaningful.
    if not use_model_tokenizer:
        return

    # Also from the original notes: bfloat16 matmuls may pick different
    # algorithms for batch=1 vs batch=N and differ by 1 ULP, so the comparison
    # is done in float32, where differences are expected to be < 1e-3.
    model.to(torch.float32)
    fp32_kwargs = dict(common_kwargs, embed_dtype=torch.float32)
    batch_embs = model.embed_dataset(sequences=sequences, batch_size=batch_size, **fp32_kwargs)
    single_embs = model.embed_dataset(sequences=sequences, batch_size=1, **fp32_kwargs)

    assert_no_nan(batch_embs, f"{name} match test batch_size={batch_size}")
    assert_no_nan(single_embs, f"{name} match test batch_size=1")
    assert_embeddings_match(batch_embs, single_embs, name)
    print(f" Match test batch_size={batch_size} vs 1: OK (non-pad tokens only)")
if __name__ == "__main__":
    # Script entry point: run the standalone unit tests, then the per-model
    # embedding tests for every model selected on the command line.
    import argparse

    # Index the configs by display name once. Dict order follows MODEL_CONFIGS,
    # so the default model list stays in sync with the table.
    configs_by_name = {cfg[0]: cfg for cfg in MODEL_CONFIGS}

    parser = argparse.ArgumentParser(description="Test embed_dataset produces no NaN with batch_size > 1.")
    parser.add_argument(
        "--models",
        nargs="+",
        default=list(configs_by_name),
        # argparse rejects unknown names at parse time with a usage error.
        # The previous `assert` was stripped under `python -O` and only ran
        # after the unit tests had already executed.
        choices=sorted(configs_by_name),
        help="Display names of the models to test (default: all).",
    )
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
    args = parser.parse_args()

    # Seed before any random sequence is generated so runs are repeatable.
    random.seed(SEED)

    test_parse_fasta()
    test_dplm2_multimodal_layout_guard()
    test_dplm2_special_token_normalization()

    for model_name in args.models:
        name, model_cls, model_path, use_model_tokenizer = configs_by_name[model_name]
        test_model(name, model_cls, model_path, use_model_tokenizer, args.batch_size)
    print("\nAll tests passed!")