Base Class and Custom Engines

The weight transfer system is built on an abstract base class that defines the contract between vLLM's worker infrastructure and the transport backend. You can implement custom backends by subclassing WeightTransferEngine and registering them with the WeightTransferEngineFactory.

WeightTransferEngine

The WeightTransferEngine is a generic abstract class parameterized by two dataclass types:

  • TInitInfo (extends WeightTransferInitInfo): Backend-specific initialization parameters.
  • TUpdateInfo (extends WeightTransferUpdateInfo): Backend-specific weight update metadata.

Abstract Methods

Subclasses must implement these four methods:

  • init_transfer_engine(init_info) (inference side): Initialize the communication channel on each inference worker.
  • receive_weights(update_info, load_weights) (inference side): Receive weights and call load_weights incrementally.
  • shutdown() (inference side): Clean up resources.
  • trainer_send_weights(iterator, trainer_args) (trainer side, static method): Send weights from the trainer process.

Request Classes

The API-level request classes provide backend-agnostic serialization using plain dictionaries. The engine's parse_init_info and parse_update_info methods convert these dictionaries into typed dataclasses.

from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)

# Init request (dict is converted to backend-specific TInitInfo)
init_request = WeightTransferInitRequest(
    init_info={"master_address": "10.0.0.1", "master_port": 29500, ...}
)

# Update request (dict is converted to backend-specific TUpdateInfo)
update_request = WeightTransferUpdateRequest(
    update_info={"names": [...], "dtype_names": [...], "shapes": [...]}
)

WeightTransferUpdateInfo

The base WeightTransferUpdateInfo includes an is_checkpoint_format flag:

@dataclass
class WeightTransferUpdateInfo(ABC):
    is_checkpoint_format: bool = True

When is_checkpoint_format=True (the default), vLLM applies layerwise weight processing (repacking, renaming, etc.) on the received weights before loading them. Set to False if the trainer has already converted weights to the kernel format expected by the model.
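The following is a minimal, self-contained sketch of how the flag flows from a plain update dict into a typed dataclass. The class and function names here are illustrative stand-ins, not vLLM's actual parse_update_info implementation:

```python
from dataclasses import dataclass, field, fields

@dataclass
class SketchUpdateInfo:
    # Stand-in for a backend-specific TUpdateInfo dataclass
    names: list = field(default_factory=list)
    dtype_names: list = field(default_factory=list)
    shapes: list = field(default_factory=list)
    is_checkpoint_format: bool = True  # default: apply layerwise processing

def parse_update_info(info_dict, cls=SketchUpdateInfo):
    # Keep only the keys the dataclass declares, so a plain dict from the
    # API-level request can be converted into the typed info object
    allowed = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in info_dict.items() if k in allowed})

info = parse_update_info({
    "names": ["layer.weight"],
    "dtype_names": ["bfloat16"],
    "shapes": [[16, 16]],
    # Trainer already produced kernel-format weights, so skip repacking
    "is_checkpoint_format": False,
})
```

If the dict omits the key, the dataclass default (True) applies, so checkpoint-format processing is the safe fallback.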

Implementing a Custom Engine

To create a custom weight transfer backend:

1. Define Info Dataclasses

from dataclasses import dataclass
from vllm.distributed.weight_transfer.base import (
    WeightTransferEngine,
    WeightTransferInitInfo,
    WeightTransferUpdateInfo,
)

@dataclass
class MyInitInfo(WeightTransferInitInfo):
    endpoint: str
    token: str

@dataclass
class MyUpdateInfo(WeightTransferUpdateInfo):
    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    # Add custom fields as needed

2. Implement the Engine

from collections.abc import Callable, Iterator
from typing import Any
import torch

class MyWeightTransferEngine(WeightTransferEngine[MyInitInfo, MyUpdateInfo]):
    init_info_cls = MyInitInfo
    update_info_cls = MyUpdateInfo

    def init_transfer_engine(self, init_info: MyInitInfo) -> None:
        # Set up connection to trainer using init_info.endpoint, etc.
        ...

    def receive_weights(
        self,
        update_info: MyUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        # Receive each weight and call load_weights incrementally
        for name, dtype_name, shape in zip(
            update_info.names, update_info.dtype_names, update_info.shapes
        ):
            dtype = getattr(torch, dtype_name)
            weight = self._fetch_weight(name, shape, dtype)
            load_weights([(name, weight)])

    def shutdown(self) -> None:
        # Clean up resources
        ...

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any],
    ) -> None:
        # Send weights from the trainer process
        for name, tensor in iterator:
            # Send tensor via custom transport
            ...

Important

The load_weights callable passed to receive_weights should be called incrementally (one or a few weights at a time) rather than accumulating all weights first. This avoids GPU out-of-memory errors with large models.
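The incremental pattern can be sketched as a small chunking helper. This is an illustrative sketch, not a vLLM API; the point is that each call to load_weights lets the staging buffers be freed before the next batch arrives:

```python
def receive_in_chunks(weights_iter, load_weights, chunk_size=2):
    """Feed (name, tensor) pairs to load_weights in small batches
    instead of accumulating the full state dict in memory."""
    chunk = []
    for item in weights_iter:
        chunk.append(item)
        if len(chunk) >= chunk_size:
            load_weights(chunk)  # staging buffers can be freed here
            chunk = []
    if chunk:
        load_weights(chunk)  # flush the final partial batch
```

With a 70B-parameter model, accumulating every tensor before a single load_weights call would require a second full copy of the weights in memory; chunking keeps the peak overhead at roughly one chunk's worth of tensors.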

3. Register with the Factory

from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Option 1: Lazy loading (recommended for built-in engines)
WeightTransferEngineFactory.register_engine(
    "my_backend",
    "my_package.my_module",
    "MyWeightTransferEngine",
)

# Option 2: Direct class registration
WeightTransferEngineFactory.register_engine(
    "my_backend",
    MyWeightTransferEngine,
)

Once registered, users can select your backend via WeightTransferConfig(backend="my_backend").

WeightTransferEngineFactory

The factory uses a registry pattern with lazy loading. Built-in engines (nccl and ipc) are registered at import time but their modules are only loaded when the backend is actually requested. This avoids importing heavy dependencies (like NCCL communicators) when they aren't needed.

from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Create an engine from config
engine = WeightTransferEngineFactory.create_engine(
    config=weight_transfer_config,
    parallel_config=parallel_config,
)
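The lazy-loading behavior can be illustrated with a minimal registry sketch. The names below are assumptions for illustration, not vLLM's actual factory code; the key idea is storing import coordinates and deferring the import until create_engine is called:

```python
import importlib

class SketchFactory:
    _registry: dict = {}

    @classmethod
    def register_engine(cls, name, module_path, class_name=None):
        if class_name is None:
            # Direct registration: module_path is actually the class itself
            cls._registry[name] = module_path
        else:
            # Lazy registration: record where the class lives, import later
            cls._registry[name] = (module_path, class_name)

    @classmethod
    def create_engine(cls, name, *args, **kwargs):
        entry = cls._registry[name]
        if isinstance(entry, tuple):
            # First request for this backend: import the module now
            module_path, class_name = entry
            entry = getattr(importlib.import_module(module_path), class_name)
            cls._registry[name] = entry  # cache the resolved class
        return entry(*args, **kwargs)
```

Registering a backend this way costs only a dictionary entry; the module (and any heavy transitive dependencies) is imported on the first create_engine call for that backend name.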