Visual Document Retrieval
Tags: Transformers · Safetensors · ColPali · multilingual · colvec1 · feature-extraction · text · image · video · multimodal-embedding · vidore · colqwen3_5 · multilingual-embedding · custom_code
Instructions for using webAI-Official/webAI-ColVec1-4b with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use webAI-Official/webAI-ColVec1-4b with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "webAI-Official/webAI-ColVec1-4b",
    trust_remote_code=True,
    dtype="auto",
)
```
- ColPali
How to use webAI-Official/webAI-ColVec1-4b with ColPali:
```python
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
```
- Notebooks
- Google Colab
- Kaggle
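As a sketch of end-to-end usage with plain Transformers and PyTorch: the snippet below assumes the repository exposes a processor loadable via `AutoProcessor` and that it accepts text-only and image-only batches the way other ColPali-style processors do; those preprocessing details are assumptions, so check the repository files for the exact calls. The MaxSim scoring follows from the per-token, L2-normalized embeddings returned by the modeling code reproduced further below.

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "webAI-Official/webAI-ColVec1-4b"

# trust_remote_code=True is required because the repo ships custom ColVec1 modeling code.
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, dtype="auto")
model.eval()

# Assumption: the repository also provides a processor loadable via AutoProcessor.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

query_inputs = processor(text=["total revenue in Q3"], return_tensors="pt", padding=True)
page_image = Image.open("page_0.png")  # hypothetical document page image
doc_inputs = processor(images=[page_image], return_tensors="pt")

with torch.no_grad():
    q_emb = model(**query_inputs)  # (1, n_query_tokens, embed_dim), L2-normalized, padding zeroed
    d_emb = model(**doc_inputs)    # (1, n_page_tokens, embed_dim)

# ColPali-style late interaction (MaxSim): for every query token take its best-matching
# page token, then sum the maxima over query tokens.
sim = torch.einsum("qnd,pmd->qpnm", q_emb.float(), d_emb.float())
scores = sim.max(dim=-1).values.sum(dim=-1)  # shape (n_queries, n_pages)
print(scores)
```

The custom modeling code shipped with the repository (the `ColVec1` wrapper) is reproduced below.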
| """ | |
| ColVec1 - ColVec1 retrieval wrapper for late interaction. | |
| """ | |
| import glob | |
| import json | |
| import os | |
| from typing import ClassVar, List, Optional | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModelForImageTextToText, PreTrainedModel | |
| from .configuration_colvec1 import ColVec1Config | |
| class ColVec1PreTrainedModel(PreTrainedModel): | |
| """Base class for ColVec1 models.""" | |
| config_class = ColVec1Config | |
| base_model_prefix = "colvec1" | |
| supports_gradient_checkpointing = True | |
| _tied_weights_keys: ClassVar[List[str]] = [] | |

class ColVec1(ColVec1PreTrainedModel):
    """
    Retrieval model wrapper for ColVec1 checkpoints.

    It loads the upstream model with `AutoModelForImageTextToText`, then adds
    a projection head to produce L2-normalized retrieval embeddings.
    """

    main_input_name: ClassVar[str] = "input_ids"

    def __init__(self, config: ColVec1Config):
        super().__init__(config)
        self.config = config
        self.vlm = None
        self.embedding_proj_layer = nn.Linear(config.text_hidden_size, config.embed_dim)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Always request hidden states from the wrapped VLM; drop caller overrides.
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)
        vlm_outputs = self.vlm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        if hasattr(vlm_outputs, "hidden_states") and vlm_outputs.hidden_states is not None:
            last_hidden_states = vlm_outputs.hidden_states[-1]
        elif hasattr(vlm_outputs, "last_hidden_state"):
            last_hidden_states = vlm_outputs.last_hidden_state
        else:
            last_hidden_states = vlm_outputs[0]
        # Project each token's hidden state to the retrieval dimension, L2-normalize,
        # and zero out padding positions.
        embeddings = self.embedding_proj_layer(
            last_hidden_states.to(self.embedding_proj_layer.weight.dtype)
        )
        embeddings = nn.functional.normalize(embeddings, p=2, dim=-1)
        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)
        return embeddings

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        embed_dim: int = 128,
        torch_dtype: torch.dtype = None,
        device_map: str = None,
        attn_impl: str = None,
        **kwargs,
    ):
        # AutoModel may rename torch_dtype -> dtype in newer transformers.
        if torch_dtype is None:
            torch_dtype = kwargs.pop("dtype", None)
        # Pop config early so we can inspect model_type for merged-repo detection.
        # When called via AutoModel.from_pretrained, transformers resolves the config
        # and passes it here as a kwarg.
        config = kwargs.pop("config", None)
        if config is not None and hasattr(config, "embed_dim"):
            embed_dim = config.embed_dim
        # Detect a merged ColVec1 repo using three strategies in order:
        #   1. config object already provided (Hub path via AutoModel dispatch)
        #   2. local config.json on disk (direct local-path usage)
        #   3. AutoConfig.from_pretrained (direct Hub ID usage without AutoModel)
        _is_merged = (
            config is not None
            and getattr(config, "model_type", None) == "colvec1"
        )
        if not _is_merged:
            config_path = os.path.join(pretrained_model_name_or_path, "config.json")
            if os.path.exists(config_path):
                with open(config_path) as f:
                    raw = json.load(f)
                _is_merged = raw.get("model_type") == "colvec1"
            else:
                # Remote Hub ID: fetch the config to check model_type.
                from transformers import AutoConfig

                try:
                    hub_config = AutoConfig.from_pretrained(
                        pretrained_model_name_or_path,
                        trust_remote_code=kwargs.get("trust_remote_code", True),
                    )
                    _is_merged = getattr(hub_config, "model_type", None) == "colvec1"
                except Exception:
                    pass
        if _is_merged:
            return cls._load_merged(
                pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                device_map=device_map,
                attn_impl=attn_impl,
                **kwargs,
            )
        # --- From-scratch path: load a raw Qwen3.5 VLM and wrap it ---
        # (config was already popped above; rest of the method is unchanged)
        vlm_kwargs = {"trust_remote_code": kwargs.pop("trust_remote_code", True)}
        if torch_dtype is not None:
            vlm_kwargs["torch_dtype"] = torch_dtype
        if device_map is not None:
            vlm_kwargs["device_map"] = device_map
        if attn_impl is not None:
            vlm_kwargs["attn_implementation"] = attn_impl
        if "quantization_config" in kwargs:
            vlm_kwargs["quantization_config"] = kwargs.pop("quantization_config")
        vlm = AutoModelForImageTextToText.from_pretrained(pretrained_model_name_or_path, **vlm_kwargs)
        if hasattr(vlm.config, "text_config") and hasattr(vlm.config.text_config, "hidden_size"):
            text_hidden_size = vlm.config.text_config.hidden_size
        else:
            text_hidden_size = getattr(vlm.config, "hidden_size", 2560)
        model_config = ColVec1Config(
            embed_dim=embed_dim,
            text_hidden_size=text_hidden_size,
            padding_side="left",
        )
        model = cls(model_config)
        model.vlm = vlm
        model.embedding_proj_layer = nn.Linear(model_config.text_hidden_size, model_config.embed_dim)
        if torch_dtype is not None:
            model.embedding_proj_layer = model.embedding_proj_layer.to(torch_dtype)
        if hasattr(vlm, "device"):
            model.embedding_proj_layer = model.embedding_proj_layer.to(vlm.device)
        # Re-expose the wrapped VLM's tied-weight keys under the `vlm.` prefix so
        # transformers' save/load logic keeps handling them correctly.
        tied = getattr(vlm, "_tied_weights_keys", None)
        if isinstance(tied, dict):
            model._tied_weights_keys = {f"vlm.{k}": f"vlm.{v}" for k, v in tied.items()}
        elif isinstance(tied, (list, tuple, set)):
            model._tied_weights_keys = [f"vlm.{k}" for k in tied]
        else:
            model._tied_weights_keys = []
        return model

    @classmethod
    def _load_merged(
        cls,
        path: str,
        torch_dtype: torch.dtype = None,
        device_map: str = None,
        attn_impl: str = None,
        **kwargs,
    ):
        """Load a merged ColVec1 checkpoint (dense VLM weights + embedding_proj_layer)."""
        from safetensors.torch import load_file

        # Resolve a Hub repo ID to a local cached snapshot directory so all
        # subsequent os.path / glob operations work for both local and remote paths.
        if not os.path.isdir(path):
            from huggingface_hub import snapshot_download

            path = snapshot_download(path)
        config = ColVec1Config.from_pretrained(path)
        base_name = config.base_model_name_or_path
        if base_name is None:
            raise ValueError(
                f"Merged ColVec1 config at {path} is missing 'base_model_name_or_path'. "
                "This field is required to know which VLM architecture to instantiate."
            )
        vlm_kwargs = {"trust_remote_code": True}
        if torch_dtype is not None:
            vlm_kwargs["torch_dtype"] = torch_dtype
        if device_map is not None:
            vlm_kwargs["device_map"] = device_map
        if attn_impl is not None:
            vlm_kwargs["attn_implementation"] = attn_impl
        vlm = AutoModelForImageTextToText.from_pretrained(base_name, **vlm_kwargs)
        model = cls(config)
        model.vlm = vlm
        # Load the merged weights (VLM + projection head) from the checkpoint shards.
        safetensor_files = sorted(glob.glob(os.path.join(path, "model*.safetensors")))
        if not safetensor_files:
            raise FileNotFoundError(f"No model*.safetensors files found in {path}")
        state_dict = {}
        for sf in safetensor_files:
            state_dict.update(load_file(sf))
        model.load_state_dict(state_dict, strict=False)
        if torch_dtype is not None:
            model.embedding_proj_layer = model.embedding_proj_layer.to(torch_dtype)
        if hasattr(vlm, "device"):
            model.embedding_proj_layer = model.embedding_proj_layer.to(vlm.device)
        tied = getattr(vlm, "_tied_weights_keys", None)
        if isinstance(tied, dict):
            model._tied_weights_keys = {f"vlm.{k}": f"vlm.{v}" for k, v in tied.items()}
        elif isinstance(tied, (list, tuple, set)):
            model._tied_weights_keys = [f"vlm.{k}" for k in tied]
        else:
            model._tied_weights_keys = []
        return model

    def tie_weights(self, *args, **kwargs):
        if self.vlm is None:
            # Called during post_init() before the wrapped VLM is attached.
            return None
        try:
            return self.vlm.tie_weights(*args, **kwargs)
        except TypeError:
            return self.vlm.tie_weights()

    def get_input_embeddings(self):
        return self.vlm.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.vlm.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.vlm.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.vlm.set_output_embeddings(new_embeddings)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        model_embeds = self.vlm.resize_token_embeddings(
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
            mean_resizing=mean_resizing,
        )
        # Keep the config's vocab size in sync with the resized embedding matrix.
        if hasattr(self.vlm.config, "text_config"):
            self.vlm.config.text_config.vocab_size = model_embeds.num_embeddings
        if hasattr(self.vlm.config, "vocab_size"):
            self.vlm.config.vocab_size = model_embeds.num_embeddings
        return model_embeds

    @property
    def device(self):
        return next(self.parameters()).device

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if self.vlm is not None and hasattr(self.vlm, "gradient_checkpointing_enable"):
            self.vlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        if self.vlm is not None and hasattr(self.vlm, "gradient_checkpointing_disable"):
            self.vlm.gradient_checkpointing_disable()


__all__ = ["ColVec1", "ColVec1PreTrainedModel"]
```
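As a usage note on the loading logic above: `from_pretrained` handles both a merged ColVec1 repository (detected by `model_type == "colvec1"`) and a raw image-text-to-text base checkpoint that gets wrapped with a freshly initialized projection head. A minimal sketch of both paths follows; the base checkpoint ID in the second path is a placeholder, not a published model.

```python
import torch
from transformers import AutoModel

# Merged repo: AutoModel dispatches to ColVec1.from_pretrained, which rebuilds the base
# VLM named in config.base_model_name_or_path and loads the merged model*.safetensors
# shards on top (including embedding_proj_layer).
model = AutoModel.from_pretrained(
    "webAI-Official/webAI-ColVec1-4b",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)

# From-scratch wrapping (e.g. before fine-tuning): pass a raw base VLM checkpoint and a
# projection dimension; the head is newly initialized, not trained. "<base-vlm-checkpoint>"
# is a placeholder, and importing ColVec1 directly assumes the repo files are checked out.
# from modeling_colvec1 import ColVec1
# model = ColVec1.from_pretrained("<base-vlm-checkpoint>", embed_dim=128, torch_dtype=torch.bfloat16)
```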