Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import torch | |
| from safetensors.torch import load_file | |
| from rich.console import Console | |
| from rich.table import Table | |
| from transformers import AutoModelForMaskedLM, AutoConfig, AutoModel | |
| from e1_fastplms.modeling_e1 import E1ForMaskedLM, E1Config, E1Model | |
def _unwrap_state_dict(obj):
    """Return the nested weights dict when *obj* wraps it under "state_dict" or "model".

    Only a nested *dict* counts as a wrapper, so a flat state dict that merely
    has a tensor stored under the key "model" is returned unchanged.
    """
    if isinstance(obj, dict):
        for wrapper_key in ("state_dict", "model"):
            if isinstance(obj.get(wrapper_key), dict):
                return obj[wrapper_key]
    return obj


def load_weights(path, cast_fp32=True):
    """Load a checkpoint from *path* and return it as a ``{name: value}`` dict.

    Supported formats:
      * ``.safetensors`` -- read with ``safetensors.torch.load_file``.
      * ``.pth`` / ``.pt`` -- read with ``torch.load(weights_only=True)`` on CPU.
      * any other extension -- try safetensors first, then fall back to
        ``torch.load``.

    Whenever ``torch.load`` is used, a checkpoint that wraps its weights under a
    ``"state_dict"`` or ``"model"`` key is unwrapped (that order of precedence).

    Args:
        path: Filesystem path of the checkpoint.
        cast_fp32: If True, every ``torch.Tensor`` value is converted with
            ``.float()``; non-tensor values are passed through unchanged.

    Returns:
        Dict mapping parameter names to the stored values.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    if not os.path.exists(path):
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(f"File {path} not found.")

    if path.endswith(".safetensors"):
        sd = load_file(path)
    elif path.endswith((".pth", ".pt")):
        sd = _unwrap_state_dict(torch.load(path, map_location="cpu", weights_only=True))
    else:
        # Unknown extension: sniff the format. safetensors raises its own error
        # type on non-safetensors input, hence the deliberately broad except.
        try:
            sd = load_file(path)
        except Exception:
            sd = _unwrap_state_dict(torch.load(path, map_location="cpu", weights_only=True))

    if cast_fp32:
        return {k: v.float() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
    return sd
def _export_default_checkpoints():
    """Export the state dicts that the default ``--file1`` / ``--files`` paths point to.

    Loads ``Profluent-Bio/E1-150m`` with ``E1ForMaskedLM`` and
    ``Synthyra/Profluent-E1-150M`` with both ``AutoModel`` and
    ``AutoModelForMaskedLM`` (all in float32). It then writes ``official.pth``,
    ``load_from_pretrained_1.pth`` and ``load_from_pretrained_2.pth`` to the
    current working directory, overwriting existing files with those names.
    """
    official = E1ForMaskedLM.from_pretrained('Profluent-Bio/E1-150m', dtype=torch.float32).eval()
    torch.save(official.state_dict(), 'official.pth')
    del official
    base = AutoModel.from_pretrained(
        'Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True
    ).eval()
    torch.save(base.state_dict(), 'load_from_pretrained_1.pth')
    del base
    mlm = AutoModelForMaskedLM.from_pretrained(
        'Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True
    ).eval()
    torch.save(mlm.state_dict(), 'load_from_pretrained_2.pth')


def _compare_tensors(key, ref_w, other_w, strict):
    """Compare one tensor present in both the reference and a candidate checkpoint.

    Args:
        key: Tensor name, used only in error messages.
        ref_w: Value from the reference checkpoint.
        other_w: Value from the candidate checkpoint.
        strict: Selects the label used in the mismatch cell.

    Returns:
        ``(cell, exact)``. ``cell`` is Rich markup for the table. ``exact`` is
        True only when the shapes match and ``torch.equal`` holds.

    Raises:
        TypeError: If either value is not a ``torch.Tensor``.
    """
    if not isinstance(ref_w, torch.Tensor):
        raise TypeError(f"Weight {key} in reference is not a tensor.")
    if not isinstance(other_w, torch.Tensor):
        raise TypeError(f"Weight {key} in comparison file is not a tensor.")
    if ref_w.shape != other_w.shape:
        return "[red]✗ (Shape)[/red]", False
    # torch.equal rather than `mse == 0`: tiny float32 differences can square
    # to 0 (underflow), and the mean over an empty tensor is NaN.
    if torch.equal(ref_w, other_w):
        return "[green]✓[/green]", True
    mse = torch.mean((ref_w.float() - other_w.float()) ** 2).item()
    label = "Strict, MSE" if strict else "MSE"
    return f"[red]✗ ({label}: {mse:.2e})[/red]", False


def main():
    """Compare the weights of several checkpoints against a reference and print a table.

    Each tensor name found in any checkpoint becomes a row. Each non-reference
    checkpoint becomes a column whose cell shows whether that tensor matches
    the reference. With ``--strict``, tensors keep their stored dtype;
    otherwise ``load_weights`` casts them to float32 first.

    With ``--strict --assert_exact``, any non-exact cell raises
    ``AssertionError`` after the table is printed. Non-exact cells are: a
    tensor missing from one side, a shape mismatch, or a value mismatch.

    When ``--file1`` or ``--files`` is omitted, the default checkpoints are
    first exported from the hub (see ``_export_default_checkpoints``).
    """
    parser = argparse.ArgumentParser(description="Compare checkpoint weights against a reference.")
    parser.add_argument("--file1", type=str, default=None,
                        help="Reference checkpoint (default: official.pth, exported from the hub).")
    parser.add_argument("--files", type=str, nargs="+", default=None,
                        help="Checkpoints to compare against the reference.")
    parser.add_argument("--strict", action="store_true",
                        help="Compare in the stored dtype instead of casting to float32.")
    parser.add_argument("--assert_exact", action="store_true",
                        help="Together with --strict, fail if any tensor is not an exact match.")
    args = parser.parse_args()
    console = Console()

    if args.file1 is None or args.files is None:
        # Only download/export when a default path is actually going to be read.
        _export_default_checkpoints()
    if args.file1 is None:
        args.file1 = 'official.pth'
    if args.files is None:
        args.files = ['load_from_pretrained_1.pth', 'load_from_pretrained_2.pth', 'old.safetensors']
    if args.assert_exact and not args.strict:
        console.print("[yellow]--assert_exact has no effect without --strict.[/yellow]")

    paths = [args.file1] + args.files
    sds = [load_weights(p, cast_fp32=not args.strict) for p in paths]
    ref_sd = sds[0]
    candidates = list(zip((os.path.basename(p) for p in paths[1:]), sds[1:]))
    all_keys = sorted(set().union(*(sd.keys() for sd in sds)))

    table = Table(title=f"Weights Comparison (Reference: {os.path.basename(paths[0])})")
    table.add_column("Tensor Name", style="cyan", no_wrap=True)
    for name, _ in candidates:
        table.add_column(f"{name} == Ref", justify="center")

    # "<tensor> (<file>)" for every cell that is not an exact match, including
    # tensors missing on one side and shape mismatches.
    mismatches = []
    for k in all_keys:
        row = [k]
        in_ref = k in ref_sd
        for name, sd in candidates:
            in_other = k in sd
            if not in_ref and not in_other:
                row.append("[dim]-[/dim]")
                continue
            if in_ref and in_other:
                cell, exact = _compare_tensors(k, ref_sd[k], sd[k], args.strict)
            else:
                cell, exact = "[red]✗[/red]", False
            row.append(cell)
            if not exact:
                mismatches.append(f"{k} ({name})")
        table.add_row(*row)
    console.print(table)

    if args.strict and args.assert_exact and mismatches:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise AssertionError(
            f"Found {len(mismatches)} strict mismatches. "
            f"First mismatches: {mismatches[:10]}"
        )
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()