Skip to content

vllm.model_executor.models.colqwen3_5

ColQwen3.5 late interaction model for multi-modal retrieval and reranking.

ColQwen3.5 extends Qwen3.5 with a ColBERT-style late interaction head, producing per-token embeddings for both text and image inputs. It uses MaxSim scoring for retrieval/reranking tasks.

This model supports the "token_embed" pooling task and is designed for multi-vector retrieval of documents containing both text and images.

Reference: https://arxiv.org/abs/2407.01449 (ColPali).

Based on: the Qwen3.5 backbone with a custom text projection.

Target models:

- athrael-soju/colqwen3.5-4.5B-v3

ColQwen3_5Model

Bases: Qwen3_5ForConditionalGeneration, SupportsLateInteraction

ColQwen3.5 late interaction model for multi-modal retrieval/reranking.

This model extends Qwen3_5ForConditionalGeneration with a ColBERT-style linear projection layer for per-token embeddings. It supports:

- "token_embed" task: per-token embeddings for late interaction scoring

The model produces per-token embeddings by:

1. Running the Qwen3.5 backbone (vision + language) to get hidden states.
2. Projecting the hidden states through a linear layer (hidden_size -> embed_dim).
3. Leaving L2 normalization to the pooler (PoolerNormalize).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `custom_text_proj` | `nn.Linear` | Linear projection from `hidden_size` to `embed_dim` |

Source code in vllm/model_executor/models/colqwen3_5.py
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=ColQwen3_5ProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class ColQwen3_5Model(
    Qwen3_5ForConditionalGeneration,
    SupportsLateInteraction,
):
    """ColQwen3.5 late interaction model for multi-modal retrieval/reranking.

    This model extends Qwen3_5ForConditionalGeneration with a ColBERT-style
    linear projection layer for per-token embeddings. It supports:
    - "token_embed" task: Per-token embeddings for late interaction scoring

    The model produces per-token embeddings by:
    1. Running the Qwen3.5 backbone (vision + language) to get hidden states
    2. Projecting hidden states through a linear layer (hidden_size -> embed_dim)
    3. L2 normalization is handled by the pooler via PoolerNormalize

    Attributes:
        embed_dim: Output dimension of each per-token embedding.
        custom_text_proj: Linear projection from hidden_size to embed_dim
        pooler: Token-embedding pooler built from the model's pooler config.
    """

    # Mark this as a pooling model so vLLM routes to pooler path
    is_pooling_model = True

    # Override hf_to_vllm_mapper to handle ColQwen3.5 weight naming.
    # ColPali saves weights as "language_model.*" but vLLM's
    # Qwen3_5ForCausalLM has them under "language_model.model.*".
    # Visual weights ("visual.*") already match the vLLM module path.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.": "language_model.model.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Qwen3.5 backbone, the projection head and the pooler.

        Args:
            vllm_config: Engine configuration; its ``model_config`` supplies
                the HF config, the head dtype and the pooler config.
            prefix: Module-name prefix forwarded to the backbone.

        Raises:
            ValueError: If neither ``hidden_size`` nor
                ``text_config.hidden_size`` is present on the HF config.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        config = vllm_config.model_config.hf_config
        head_dtype = vllm_config.model_config.head_dtype

        # Multimodal configs may only expose the text width on text_config.
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None and hasattr(config, "text_config"):
            hidden_size = config.text_config.hidden_size
        if hidden_size is None:
            raise ValueError(
                "Unable to determine text hidden size from config. "
                "Expected 'hidden_size' or 'text_config.hidden_size'."
            )

        # First truthy value among the attribute names used by different
        # ColPali-style configs (ColPali: dim, projection_dim, colbert_dim).
        self.embed_dim: int = (
            getattr(config, "embed_dim", None)
            or getattr(config, "dims", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
            or getattr(config, "colbert_dim", None)
            or 128  # default from reference implementation
        )

        self.custom_text_proj = nn.Linear(
            hidden_size,
            self.embed_dim,
            bias=False,
            dtype=head_dtype,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=None,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass producing per-token embeddings.

        Returns:
            The backbone hidden states projected by ``custom_text_proj`` (last
            dimension ``embed_dim``). If the backbone returns a non-tensor
            value, it is passed through unchanged. Presumably that value is
            the intermediate tensors of a non-final pipeline stage; this is
            not verified here.
        """
        hidden_states = super().forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        if not isinstance(hidden_states, torch.Tensor):
            return hidden_states  # type: ignore

        # The projection is created with head_dtype, which may differ from
        # the backbone's dtype; align before the matmul.
        proj_dtype = self.custom_text_proj.weight.dtype
        if hidden_states.dtype != proj_dtype:
            hidden_states = hidden_states.to(proj_dtype)

        # Project to embedding dimension (normalization handled by pooler)
        return self.custom_text_proj(hidden_states)

    # Names used for the projection layer across different ColQwen3.5 variants
    _PROJ_LAYER_NAMES = {
        "custom_text_proj",  # ColPali naming
        "embedding_proj_layer",  # Alternative naming
    }

    def _is_proj_weight(self, name: str) -> bool:
        """Check if a weight name belongs to the projection layer.

        Matching is by substring: any entry of ``_PROJ_LAYER_NAMES`` that
        occurs anywhere in ``name`` counts.
        """
        return any(proj_name in name for proj_name in self._PROJ_LAYER_NAMES)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with special handling for projection layer.

        Backbone weights are streamed into ``AutoWeightsLoader`` (with
        ``hf_to_vllm_mapper`` applied, and ``mtp.`` weights skipped). Weights
        whose name matches ``_is_proj_weight`` are set aside during that
        stream. They are then copied into ``custom_text_proj``, keyed by the
        last dotted component of their name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        proj_weights: list[tuple[str, torch.Tensor]] = []

        def _backbone_weights() -> Iterable[tuple[str, torch.Tensor]]:
            # Stream instead of list(weights): materializing the iterator
            # would keep every checkpoint tensor referenced simultaneously.
            # Projection weights are diverted as a side effect of iteration.
            for name, weight in weights:
                if self._is_proj_weight(name):
                    proj_weights.append((name, weight))
                else:
                    yield name, weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        loaded = loader.load_weights(
            _backbone_weights(), mapper=self.hf_to_vllm_mapper
        )

        for name, weight in proj_weights:
            param_name = name.split(".")[-1]
            # The layer is bias-free, so getattr(..., "bias") yields None and
            # a checkpoint bias (if any) is ignored here.
            param = getattr(self.custom_text_proj, param_name, None)
            if param is not None:
                weight = weight.to(device=param.device, dtype=param.dtype)
                default_weight_loader(param, weight)
                loaded.add(f"custom_text_proj.{param_name}")

        return loaded

_is_proj_weight

_is_proj_weight(name: str) -> bool

Check if a weight name belongs to the projection layer.

Source code in vllm/model_executor/models/colqwen3_5.py
def _is_proj_weight(self, name: str) -> bool:
    """Check if a weight name belongs to the projection layer."""
    return any(proj_name in name for proj_name in self._PROJ_LAYER_NAMES)

forward

forward(
    input_ids: Tensor | None,
    positions: Tensor,
    intermediate_tensors=None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
) -> Tensor

Run forward pass producing per-token embeddings.

Source code in vllm/model_executor/models/colqwen3_5.py
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors=None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor:
    """Compute per-token embeddings via the backbone and the projection head.

    Non-tensor backbone outputs are returned untouched. Tensor outputs are
    cast to the projection weight's dtype if needed and then projected to
    ``embed_dim``. Normalization is left to the pooler.
    """
    backbone_out = super().forward(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )

    if isinstance(backbone_out, torch.Tensor):
        head = self.custom_text_proj
        target_dtype = head.weight.dtype
        if backbone_out.dtype != target_dtype:
            backbone_out = backbone_out.to(target_dtype)
        return head(backbone_out)

    return backbone_out  # type: ignore

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights with special handling for projection layer.

Source code in vllm/model_executor/models/colqwen3_5.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load weights with special handling for projection layer.

    Backbone weights are streamed into ``AutoWeightsLoader`` (with
    ``hf_to_vllm_mapper`` applied, and ``mtp.`` weights skipped). Weights
    whose name matches ``_is_proj_weight`` are set aside during that stream.
    They are then copied into ``custom_text_proj``, keyed by the last dotted
    component of their name.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names that were loaded.
    """
    proj_weights: list[tuple[str, torch.Tensor]] = []

    def _backbone_weights() -> Iterable[tuple[str, torch.Tensor]]:
        # Stream instead of list(weights): materializing the iterator would
        # keep every checkpoint tensor referenced simultaneously.
        # Projection weights are diverted as a side effect of iteration.
        for name, weight in weights:
            if self._is_proj_weight(name):
                proj_weights.append((name, weight))
            else:
                yield name, weight

    loader = AutoWeightsLoader(
        self,
        skip_prefixes=["mtp."],
    )
    loaded = loader.load_weights(
        _backbone_weights(), mapper=self.hf_to_vllm_mapper
    )

    for name, weight in proj_weights:
        param_name = name.split(".")[-1]
        # A missing attribute (e.g. "bias" on a bias-free layer) yields None
        # and the checkpoint tensor is skipped.
        param = getattr(self.custom_text_proj, param_name, None)
        if param is not None:
            weight = weight.to(device=param.device, dtype=param.dtype)
            default_weight_loader(param, weight)
            loaded.add(f"custom_text_proj.{param_name}")

    return loaded

ColQwen3_5ProcessingInfo

Bases: Qwen3_5ProcessingInfo

Processing info for ColQwen3.5 models.

ColQwen3.5 models use custom HuggingFace processors (e.g. ColQwen3_5Processor) that are incompatible with vLLM's Qwen3VLMultiModalProcessor. We override get_hf_config() and get_hf_processor() to skip the strict type check and force the standard Qwen3VLProcessor.

Source code in vllm/model_executor/models/colqwen3_5.py
class ColQwen3_5ProcessingInfo(Qwen3_5ProcessingInfo):
    """Processing info for ColQwen3.5 checkpoints.

    ColQwen3.5 checkpoints ship their own Hugging Face processor class (for
    example ColQwen3_5Processor), which vLLM's Qwen3VLMultiModalProcessor
    cannot consume. This class overrides ``get_hf_config`` and
    ``get_hf_processor`` so the strict type check is bypassed and the stock
    Qwen3VLProcessor is always used. Video support is enabled only when that
    processor exposes a ``video_processor`` attribute.
    """

    def get_hf_config(self):
        """Return the HF config from the context without type checking."""
        ctx = self.ctx
        return ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
        """Return a Qwen3VLProcessor, using the fast variant by default."""
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(
            Qwen3VLProcessor,
            use_fast=use_fast,
            **kwargs,
        )

    @property
    def _supports_video(self) -> bool:
        """Check if the HF processor supports video inputs."""
        processor = self.get_hf_processor()
        return hasattr(processor, "video_processor")

    def get_video_processor(self, **kwargs: object):
        """Return the video sub-processor.

        Raises:
            AttributeError: If the HF processor has no ``video_processor``.
        """
        if self._supports_video:
            processor = self.get_hf_processor(**kwargs)
            return processor.video_processor  # type: ignore[attr-defined]
        raise AttributeError(
            f"The processor for {self.ctx.model_config.model} does not "
            "support video inputs (no video_processor attribute)."
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Image is always unlimited; video is listed only when supported."""
        modalities = ["image"]
        if self._supports_video:
            modalities.append("video")
        return dict.fromkeys(modalities)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token budget for each supported modality."""
        budget: dict[str, int] = {"image": self.get_max_image_tokens()}
        if self._supports_video:
            budget["video"] = self.get_max_video_tokens(seq_len, mm_counts)
        return budget

    def get_data_parser(self):
        """Build a Qwen2-VL data parser configured from the vision config."""
        merge_size = self.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(
            merge_size,
            video_needs_metadata=self._supports_video,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

_supports_video property

_supports_video: bool

Check if the HF processor supports video inputs.