Instructions to use max044/vl-jepa-custom with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use max044/vl-jepa-custom with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("max044/vl-jepa-custom", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from vljepa.config import Config | |
| from vljepa.models import VLJepa | |
| from vljepa.utils import nms | |
def load_model(checkpoint_path, device="cpu"):
    """Build a VLJepa model and restore its fine-tuned weights.

    Only two sub-modules are restored from the checkpoint: ``predictor`` and
    ``y_encoder.projection``. Every other parameter keeps whatever value
    ``VLJepa(config)`` gave it at construction time.

    Args:
        checkpoint_path: File readable by ``torch.load`` that contains the keys
            ``"predictor_state_dict"`` and ``"y_projection_state_dict"``.
        device: Device string written to ``config.device`` and used as
            ``map_location`` when deserializing the checkpoint.

    Returns:
        ``(model, config)`` with the model switched to eval mode.

    Raises:
        KeyError: If either expected state-dict key is missing.
    """
    cfg = Config()
    cfg.device = device
    net = VLJepa(cfg)

    print(f"Loading weights from {checkpoint_path}...")
    # weights_only=True limits unpickling to tensors and plain containers.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # NOTE(review): load_state_dict copies values into the existing parameters;
    # whether those live on `device` depends on VLJepa honoring cfg.device.
    net.predictor.load_state_dict(state["predictor_state_dict"])
    net.y_encoder.projection.load_state_dict(state["y_projection_state_dict"])

    net.eval()
    return net, cfg
def extract_frames(video_path, num_frames=16):
    """Sample ``num_frames`` evenly spaced frames from a video, as RGB arrays.

    Frame indices are spread uniformly over ``[0, total_frames - 1]`` with
    ``np.linspace`` and truncated to integers. A video shorter than
    ``num_frames`` therefore yields repeated indices. Frames that fail to
    decode are skipped, so the result may hold fewer than ``num_frames``
    entries.

    Args:
        video_path: Source passed to ``cv2.VideoCapture`` (typically a file path).
        num_frames: Number of frame indices to sample.

    Returns:
        A list of frames converted from OpenCV's BGR channel order to RGB.
        Returns an empty list if the video cannot be opened or reports no
        frames.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # try/finally guarantees the capture is released on every exit path,
        # including the early returns below and any exception from cv2.
        if not cap.isOpened():
            return []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return []

        indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
        frames = []
        for idx in indices:
            # Seek, then decode one frame at that position.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if ok:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        cap.release()
def main(checkpoint_path="best.pth", query="a person is opening a door", device=None):
    """Run a minimal VL-JEPA demo: load a checkpoint and encode one text query.

    This does not localize anything in a video. It only shows how to load the
    fine-tuned weights and obtain a text embedding. Full temporal localization
    (sliding window + NMS) is implemented in infer.py; ``extract_frames`` in
    this module shows how frames can be sampled from a video.

    Args:
        checkpoint_path: Checkpoint file passed to ``load_model``.
        query: Natural-language query to encode.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model, config = load_model(checkpoint_path, device)

    print(f"Ready for inference on {device}.")
    print(
        f"Model architecture: {config.clip_model} + {config.predictor_model} "
        f"(LoRA) + {config.text_model}"
    )

    # Raw query tokens can be inspected with
    # model.query_encoder.tokenize([query], device=device); encode_text below
    # is the call that actually produces the embedding.
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        text_embedding = model.encode_text([query], device=device)

    print(f"Query: '{query}'")
    print(f"Text embedding shape: {text_embedding.shape}")
    print("\nTo perform full temporal localization, use the infer.py script which implements sliding window and NMS.")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()