Skip to content

RLHF NCCL FSDP EP

Source: https://github.com/vllm-project/vllm/blob/main/examples/rl/rlhf_nccl_fsdp_ep.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
RLHF with FSDP2 training (4 GPUs) and vLLM expert-parallel inference (4 GPUs).

8-GPU layout:
  Training  — 4 GPUs, PyTorch FSDP2 (fully_shard)
  Inference — 4 GPUs, vLLM AsyncLLMEngine with expert parallelism +
              data parallelism (TP=1, DP=4, enable_expert_parallel
              → EP_SIZE = TP×DP = 4)

FSDP workers are Ray actors that form a single FSDP2 process group.
All ranks gather full parameters via DTensor.full_tensor() (a collective),
and rank 0 broadcasts the gathered tensors to the vLLM inference engine
through the NCCL weight-transfer API.

The inference engine uses AsyncLLMEngine which automatically spawns
DP worker processes (no manual placement group needed).  Weight sync
uses pause_generation / resume_generation.

Steps:
  1. Launch 4 FSDP training workers.
  2. Launch AsyncLLMEngine with EP+DP (dummy weights).
  3. Generate from prompts → gibberish (random weights).
  4. Pause generation, transfer weights from FSDP, resume.
  5. Generate from prompts → sensible output (synced weights).

Assumes a single-node cluster with 8 GPUs.
"""

import asyncio
import os
import uuid
from dataclasses import asdict

import ray
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForCausalLM

import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor

# Hugging Face model id.  main() downloads it once with snapshot_download and
# passes the local path to both the FSDP trainers and the vLLM engine.
MODEL_NAME = "Qwen/Qwen3-30B-A3B"

# Number of FSDP2 training ranks; each is one Ray actor holding one GPU.
FSDP_WORLD_SIZE = 4
# vLLM inference layout.  With enable_expert_parallel the engine runs
# INFERENCE_TP_SIZE * INFERENCE_DP_SIZE workers, and each of them joins the
# weight-transfer group alongside the single trainer rank.
INFERENCE_TP_SIZE = 1
INFERENCE_DP_SIZE = 4


@ray.remote(num_gpus=1)
class FSDPTrainWorker:
    """
    One FSDP2 training worker per GPU.  Four of these form the FSDP group.
    Rank 0 additionally handles weight transfer to the vLLM engine.

    Call order (as driven by ``main``):
      * ``__init__`` on every rank: join the FSDP process group, load the
        model in bf16 and shard it with ``fully_shard``.
      * rank 0 only: ``setup_transfer_endpoint`` and then
        ``init_weight_transfer_group``.
      * every rank, concurrently: ``gather_and_broadcast_weights``.
    """

    def __init__(
        self,
        model_name: str,
        rank: int,
        fsdp_world_size: int,
        fsdp_master_addr: str,
        fsdp_master_port: int,
    ):
        """Join the FSDP2 process group, load ``model_name`` and shard it.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            rank: This worker's rank within the FSDP process group.
            fsdp_world_size: Total number of FSDP ranks.
            fsdp_master_addr: Rendezvous address for ``init_process_group``.
            fsdp_master_port: Rendezvous port for ``init_process_group``.
        """
        self.rank = rank

        # init_process_group's default env:// rendezvous reads these variables.
        os.environ["MASTER_ADDR"] = fsdp_master_addr
        os.environ["MASTER_PORT"] = str(fsdp_master_port)

        # Select the device *before* creating the NCCL process group so the
        # communicator is bound to it rather than to an implicit default.
        # Index 0 assumes Ray limits this actor's visibility to its single
        # reserved GPU (num_gpus=1).
        torch.accelerator.set_device_index(0)
        dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )

        # Record name / dtype / shape of every parameter before sharding.  The
        # driver forwards these lists to vLLM as weight-update metadata, so
        # they are built from a single pass to stay in lockstep.
        named_params = list(model.named_parameters())
        self.weight_names = [name for name, _ in named_params]
        self.weight_dtype_names = [
            str(param.dtype).split(".")[-1] for _, param in named_params
        ]
        self.weight_shapes = [list(param.shape) for _, param in named_params]

        # Each decoder layer becomes its own FSDP unit.  The root call then
        # covers any parameters that live outside the layers.
        for layer in model.model.layers:
            fully_shard(layer)
        fully_shard(model)

        self.model = model

        # Populated on rank 0 by setup_transfer_endpoint / init_weight_transfer_group.
        self.transfer_port = None
        self.transfer_master_address = None
        self.model_update_group = None

    def get_rank(self):
        """Return this worker's FSDP rank (the driver also uses it as a readiness probe)."""
        return self.rank

    def _require_rank0(self, method: str) -> None:
        """Raise ``RuntimeError`` unless this worker is FSDP rank 0.

        An explicit raise is used instead of ``assert`` so the check still
        runs under ``python -O``.
        """
        if self.rank != 0:
            raise RuntimeError(
                f"{method}() is only valid on FSDP rank 0, "
                f"but was called on rank {self.rank}"
            )

    # ---- weight-transfer setup (rank 0 only) ----

    def setup_transfer_endpoint(self):
        """Create the NCCL rendezvous endpoint for weight transfer.

        Returns:
            ``(address, port)``, which the driver hands to the vLLM side.

        Raises:
            RuntimeError: if called on a rank other than 0.
        """
        self._require_rank0("setup_transfer_endpoint")
        self.transfer_port = get_open_port()
        self.transfer_master_address = get_ip()
        return self.transfer_master_address, self.transfer_port

    def init_weight_transfer_group(self, transfer_world_size: int):
        """Join the weight-transfer NCCL group as rank 0 (the source).

        Args:
            transfer_world_size: 1 (this trainer) plus the number of vLLM
                workers that join the group.

        Raises:
            RuntimeError: if called on a rank other than 0, or before
                ``setup_transfer_endpoint``.
        """
        self._require_rank0("init_weight_transfer_group")
        if self.transfer_master_address is None or self.transfer_port is None:
            raise RuntimeError(
                "setup_transfer_endpoint() must be called before "
                "init_weight_transfer_group()"
            )
        # NOTE(review): presumably blocks until every rank has joined.  main()
        # therefore launches this call before initializing the vLLM side.
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.transfer_master_address,
                master_port=self.transfer_port,
                world_size=transfer_world_size,
            ),
        )

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes captured before FSDP wrapping."""
        return self.weight_names, self.weight_dtype_names, self.weight_shapes

    # ---- collective ops (ALL FSDP ranks must call concurrently) ----

    def gather_and_broadcast_weights(self, packed: bool = True):
        """
        All-gather full parameters and broadcast them to vLLM.
        Only rank 0 performs the actual NCCL broadcast; others just
        participate in the FSDP all-gather.

        full_tensor() is a collective — all FSDP ranks must call it
        for each parameter in the same order.  Rank 0 additionally
        feeds each gathered tensor to the weight-transfer engine.

        Args:
            packed: Forwarded to ``NCCLTrainerSendWeightsArgs``; must agree
                with the ``packed`` flag the receiving engine is given.
        """
        if self.rank == 0:

            # Lazy generator: each full_tensor() all-gather runs only when the
            # transfer engine pulls the next item.  This keeps rank 0 in step
            # with the other ranks' loop below.
            def _full_param_iter():
                for name, param in self.model.named_parameters():
                    yield name, param.full_tensor()

            trainer_args = NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            )
            NCCLWeightTransferEngine.trainer_send_weights(
                iterator=_full_param_iter(),
                trainer_args=trainer_args,
            )
        else:
            # Non-source ranks only take part in the all-gathers.  The gathered
            # tensors are discarded.
            for _, param in self.model.named_parameters():
                param.full_tensor()


def create_async_engine(**kwargs):
    """Instantiate an ``AsyncLLMEngine`` from plain engine keyword arguments.

    ``kwargs`` are forwarded unchanged to ``vllm.AsyncEngineArgs``.  The
    resulting config picks the executor class.  Request and stats logging
    follow the engine args' ``enable_log_requests`` / ``disable_log_stats``
    flags.
    """
    args = vllm.AsyncEngineArgs(**kwargs)
    config = args.create_engine_config()
    executor_cls = Executor.get_class(config)
    return vllm.AsyncLLMEngine(
        vllm_config=config,
        executor_class=executor_cls,
        log_requests=args.enable_log_requests,
        log_stats=not args.disable_log_stats,
    )


async def generate_batch(engine, prompts, sampling_params):
    """Submit every prompt concurrently and collect each request's final output.

    Args:
        engine: Object whose ``generate(inputs, sampling_params, request_id=...)``
            returns an async iterator of streamed outputs.
        prompts: Prompt strings to submit.
        sampling_params: Sampling parameters shared by all requests.

    Returns:
        A list aligned with ``prompts`` holding the last item each request
        streamed, or ``None`` for a stream that produced nothing.
    """

    async def _last_output(text):
        # Each request gets a fresh unique id.  Drain the stream and keep
        # only the final element.
        final = None
        stream = engine.generate(
            {"prompt": text},
            sampling_params,
            request_id=str(uuid.uuid4()),
        )
        async for final in stream:
            pass
        return final

    pending = [_last_output(text) for text in prompts]
    return await asyncio.gather(*pending)


def _print_outputs(title, outputs):
    """Print a banner titled ``title`` followed by each prompt/completion pair.

    Args:
        title: Heading printed between two separator lines.
        outputs: Final request outputs (as returned by ``generate_batch``);
            each must expose ``.prompt`` and ``.outputs[0].text``.
    """
    separator = "-" * 60
    print(separator)
    print(title)
    print(separator)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print(separator)


async def _init_weight_transfer(engine, trainer):
    """Create the NCCL weight-transfer group between FSDP rank 0 and vLLM.

    Args:
        engine: The AsyncLLMEngine whose workers will receive weights.
        trainer: Ray actor handle of FSDP rank 0, the broadcast source.
    """
    print("[transfer] Setting up weight-transfer endpoint...")
    transfer_addr, transfer_port = ray.get(trainer.setup_transfer_endpoint.remote())
    print(f"[transfer] Endpoint ready at {transfer_addr}:{transfer_port}")

    # The group is one trainer rank plus every vLLM worker (TP x DP of them).
    num_inference_workers = INFERENCE_TP_SIZE * INFERENCE_DP_SIZE
    transfer_world_size = num_inference_workers + 1
    print(
        f"[transfer] World size: {transfer_world_size} "
        f"(1 trainer + {num_inference_workers} vLLM workers)"
    )

    print("[transfer] Initializing NCCL groups...")
    # Both sides must join the rendezvous concurrently.  Start the trainer
    # side as a Ray task first, await the engine side, and only then wait
    # for the trainer's result.
    train_handle = trainer.init_weight_transfer_group.remote(transfer_world_size)
    await engine.init_weight_transfer_engine(
        WeightTransferInitRequest(
            init_info=asdict(
                NCCLWeightTransferInitInfo(
                    master_address=transfer_addr,
                    master_port=transfer_port,
                    # The trainer joins as rank 0, so vLLM ranks start at 1.
                    rank_offset=1,
                    world_size=transfer_world_size,
                )
            )
        )
    )
    ray.get(train_handle)
    print("[transfer] NCCL groups initialized.")


async def _sync_weights(engine, fsdp_workers):
    """Pause generation, broadcast the FSDP weights into vLLM, then resume.

    Args:
        engine: The AsyncLLMEngine receiving weights.
        fsdp_workers: All FSDP actor handles; index 0 is the broadcast source.
    """
    print("[sync] Pausing generation...")
    await engine.pause_generation(mode="abort")
    print("[sync] Generation paused.")

    names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
    print(f"[sync] Got metadata for {len(names)} parameters.")

    # A single flag for both the trainer and the engine, so they cannot disagree.
    packed = True

    print("[sync] Broadcasting weights from FSDP → vLLM...")
    # Every FSDP rank must take part in the full_tensor() all-gathers, even
    # though only rank 0 sends over the transfer group.
    broadcast_handles = [
        w.gather_and_broadcast_weights.remote(packed=packed) for w in fsdp_workers
    ]
    await engine.update_weights(
        WeightTransferUpdateRequest(
            update_info=asdict(
                NCCLWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    packed=packed,
                )
            )
        )
    )
    ray.get(broadcast_handles)
    print("[sync] Weight broadcast complete.")

    print("[sync] Resuming generation...")
    await engine.resume_generation()
    print("[sync] Generation resumed.")


async def main():
    """Run the demo: generate with dummy weights, sync from FSDP, generate again.

    The inference engine and Ray are always shut down on exit, including
    when any step fails.
    """
    ray.init()
    engine = None
    try:
        # Download model weights to local/shared disk once.
        local_model_path = snapshot_download(MODEL_NAME)
        print(f"[init] Model downloaded to {local_model_path}")

        # FSDP rendezvous address (single-node)
        fsdp_master_addr = get_ip()
        fsdp_master_port = get_open_port()

        # Launch 4 FSDP training workers.
        # Ray allocates 1 GPU per worker; AsyncLLMEngine's internal DP
        # placement groups will land on the remaining 4 GPUs.
        fsdp_workers = [
            FSDPTrainWorker.remote(
                local_model_path,
                rank,
                FSDP_WORLD_SIZE,
                fsdp_master_addr,
                fsdp_master_port,
            )
            for rank in range(FSDP_WORLD_SIZE)
        ]
        # get_rank() returns only once each actor's __init__ has finished.
        ray.get([w.get_rank.remote() for w in fsdp_workers])
        print(f"[init] {FSDP_WORLD_SIZE} FSDP training workers ready.")

        # Launch vLLM with expert parallelism + data parallelism.
        # AsyncLLMEngine with data_parallel_backend="ray" creates its own
        # placement groups internally — no manual placement group needed.
        print("[engine] Creating AsyncLLMEngine...")
        engine = create_async_engine(
            model=local_model_path,
            enforce_eager=True,
            tensor_parallel_size=INFERENCE_TP_SIZE,
            data_parallel_size=INFERENCE_DP_SIZE,
            enable_expert_parallel=True,
            distributed_executor_backend="ray",
            data_parallel_backend="ray",
            weight_transfer_config=WeightTransferConfig(backend="nccl"),
            load_format="dummy",
            gpu_memory_utilization=0.7,
        )
        print("[engine] AsyncLLMEngine created.")

        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        sampling_params = SamplingParams(temperature=0)

        # Generate with dummy weights — expect gibberish.
        print("[generate] Starting generation with dummy weights...")
        outputs = await generate_batch(engine, prompts, sampling_params)
        print("[generate] Generation complete.")
        _print_outputs("BEFORE weight sync (dummy weights):", outputs)

        await _init_weight_transfer(engine, fsdp_workers[0])
        await _sync_weights(engine, fsdp_workers)

        # Generate with synced weights — expect sensible output.
        print("[generate] Starting generation with synced weights...")
        outputs_updated = await generate_batch(engine, prompts, sampling_params)
        print("[generate] Generation complete.")
        _print_outputs("AFTER weight sync (real weights):", outputs_updated)
    finally:
        # The engine was created with the Ray DP backend, so it is torn down
        # before Ray itself.  Ray is shut down even if that step raises.
        try:
            if engine is not None:
                engine.shutdown()
        finally:
            ray.shutdown()


if __name__ == "__main__":
    asyncio.run(main())