RLHF NCCL FSDP EP
Source: https://github.com/vllm-project/vllm/blob/main/examples/rl/rlhf_nccl_fsdp_ep.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
RLHF with FSDP2 training (4 GPUs) and vLLM expert-parallel inference (4 GPUs).
8-GPU layout:
Training — 4 GPUs, PyTorch FSDP2 (fully_shard)
Inference — 4 GPUs, vLLM AsyncLLMEngine with expert parallelism +
data parallelism (TP=1, DP=4, enable_expert_parallel
→ EP_SIZE = TP×DP = 4)
FSDP workers are Ray actors that form a single FSDP2 process group.
Rank 0 gathers full parameters via DTensor.full_tensor() and broadcasts
them to the vLLM inference engine through the NCCL weight-transfer API.
The inference engine uses AsyncLLMEngine which automatically spawns
DP worker processes (no manual placement group needed). Weight sync
uses pause_generation / resume_generation.
Steps:
1. Launch 4 FSDP training workers.
2. Launch AsyncLLMEngine with EP+DP (dummy weights).
3. Generate from prompts → gibberish (random weights).
4. Pause generation, transfer weights from FSDP, resume.
5. Generate from prompts → sensible output (synced weights).
Assumes a single-node cluster with 8 GPUs.
"""
import asyncio
import os
import uuid
from dataclasses import asdict
import ray
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForCausalLM
import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
# Model used for both FSDP training and vLLM inference.
MODEL_NAME = "Qwen/Qwen3-30B-A3B"
# Number of FSDP training workers (one GPU each).
FSDP_WORLD_SIZE = 4
# vLLM tensor-parallel degree per data-parallel replica.
INFERENCE_TP_SIZE = 1
# vLLM data-parallel replicas; with expert parallelism enabled,
# EP size = TP x DP = 4 (see module docstring).
INFERENCE_DP_SIZE = 4
@ray.remote(num_gpus=1)
class FSDPTrainWorker:
    """
    One FSDP2 training worker per GPU. Four of these form the FSDP group.
    Rank 0 additionally handles weight transfer to the vLLM engine.
    """
    def __init__(
        self,
        model_name: str,
        rank: int,
        fsdp_world_size: int,
        fsdp_master_addr: str,
        fsdp_master_port: int,
    ):
        """Join the FSDP process group and shard the model on this GPU.

        Args:
            model_name: HF model id or local path for
                ``AutoModelForCausalLM.from_pretrained``.
            rank: This worker's rank in the FSDP group.
            fsdp_world_size: Total number of FSDP workers.
            fsdp_master_addr: Rendezvous address of the FSDP process group.
            fsdp_master_port: Rendezvous port of the FSDP process group.
        """
        self.rank = rank
        # torch.distributed reads the rendezvous endpoint from the env.
        os.environ["MASTER_ADDR"] = fsdp_master_addr
        os.environ["MASTER_PORT"] = str(fsdp_master_port)
        dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)
        # Ray gives this actor exactly one visible GPU, so local device
        # index 0 is always the correct device for this worker.
        torch.accelerator.set_device_index(0)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )
        # Capture full-parameter metadata BEFORE fully_shard() wraps the
        # parameters; vLLM needs unsharded names/dtypes/shapes for the
        # weight-transfer update request.
        self.weight_names = [n for n, _ in model.named_parameters()]
        self.weight_dtype_names = [
            str(p.dtype).split(".")[-1] for _, p in model.named_parameters()
        ]
        self.weight_shapes = [list(p.shape) for _, p in model.named_parameters()]
        # FSDP2 sharding: shard each transformer layer, then the root module.
        for layer in model.model.layers:
            fully_shard(layer)
        fully_shard(model)
        self.model = model
        # Weight-transfer state; populated on rank 0 only via
        # setup_transfer_endpoint() / init_weight_transfer_group().
        self.transfer_port = None
        self.transfer_master_address = None
        self.model_update_group = None
    def get_rank(self):
        """Return this worker's FSDP rank (also used as a readiness probe)."""
        return self.rank
    # ---- weight-transfer setup (rank 0 only) ----
    def setup_transfer_endpoint(self):
        """Create the NCCL rendezvous endpoint for weight transfer.

        Returns:
            ``(address, port)`` tuple that the vLLM workers connect to.
        """
        assert self.rank == 0
        self.transfer_port = get_open_port()
        self.transfer_master_address = get_ip()
        return self.transfer_master_address, self.transfer_port
    def init_weight_transfer_group(self, transfer_world_size: int):
        """Join the weight-transfer NCCL group as rank 0 (the source).

        Args:
            transfer_world_size: Size of the transfer group
                (1 trainer rank + all vLLM workers).
        """
        assert self.rank == 0
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.transfer_master_address,
                master_port=self.transfer_port,
                world_size=transfer_world_size,
            ),
        )
    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes captured before FSDP wrapping."""
        return self.weight_names, self.weight_dtype_names, self.weight_shapes
    # ---- collective ops (ALL FSDP ranks must call concurrently) ----
    def gather_and_broadcast_weights(self, packed: bool = True):
        """
        All-gather full parameters and broadcast them to vLLM.
        Only rank 0 performs the actual NCCL broadcast; others just
        participate in the FSDP all-gather.
        full_tensor() is a collective — all FSDP ranks must call it
        for each parameter in the same order. Rank 0 additionally
        feeds each gathered tensor to the weight-transfer engine.

        Args:
            packed: Pack tensors before sending; must match the
                ``packed`` flag in the receiver's update request.
        """
        if self.rank == 0:
            # Generator: full tensors are materialized one at a time as the
            # transfer engine consumes them, rather than all up front.
            def _full_param_iter():
                for name, param in self.model.named_parameters():
                    yield name, param.full_tensor()
            trainer_args = NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            )
            NCCLWeightTransferEngine.trainer_send_weights(
                iterator=_full_param_iter(),
                trainer_args=trainer_args,
            )
        else:
            # Non-zero ranks must still join every all-gather, in the same
            # parameter order as rank 0; the gathered result is discarded.
            for _, param in self.model.named_parameters():
                param.full_tensor()
def create_async_engine(**kwargs):
    """Build an AsyncLLMEngine from keyword engine args (no subclass needed).

    All ``kwargs`` are forwarded verbatim to ``vllm.AsyncEngineArgs``.
    """
    args = vllm.AsyncEngineArgs(**kwargs)
    config = args.create_engine_config()
    return vllm.AsyncLLMEngine(
        vllm_config=config,
        executor_class=Executor.get_class(config),
        log_requests=args.enable_log_requests,
        log_stats=not args.disable_log_stats,
    )
async def generate_batch(engine, prompts, sampling_params):
    """Run one generation request per prompt and return the final outputs.

    Each prompt gets its own request id; requests run concurrently and
    results come back in prompt order.
    """

    async def _collect_final(prompt):
        # Drain the streaming generator; the last yielded item is the
        # completed output for this request.
        final = None
        async for partial in engine.generate(
            {"prompt": prompt},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            final = partial
        return final

    tasks = [_collect_final(prompt) for prompt in prompts]
    return await asyncio.gather(*tasks)
async def main():
    """Orchestrate FSDP training workers + vLLM EP inference and sync weights.

    Flow: launch 4 FSDP actors, launch the AsyncLLMEngine with dummy
    weights, generate (gibberish), NCCL-transfer the real weights from
    FSDP rank 0, then generate again (sensible output).
    """
    ray.init()
    # Download model weights to local/shared disk once.
    local_model_path = snapshot_download(MODEL_NAME)
    print(f"[init] Model downloaded to {local_model_path}")
    # FSDP rendezvous address (single-node)
    fsdp_master_addr = get_ip()
    fsdp_master_port = get_open_port()
    # Launch 4 FSDP training workers.
    # Ray allocates 1 GPU per worker; AsyncLLMEngine's internal DP
    # placement groups will land on the remaining 4 GPUs.
    fsdp_workers = [
        FSDPTrainWorker.remote(
            local_model_path,
            rank,
            FSDP_WORLD_SIZE,
            fsdp_master_addr,
            fsdp_master_port,
        )
        for rank in range(FSDP_WORLD_SIZE)
    ]
    # Block until every actor has finished __init__ (joined the FSDP
    # process group and sharded its model).
    ray.get([w.get_rank.remote() for w in fsdp_workers])
    print(f"[init] {FSDP_WORLD_SIZE} FSDP training workers ready.")
    # Launch vLLM with expert parallelism + data parallelism.
    # AsyncLLMEngine with data_parallel_backend="ray" creates its own
    # placement groups internally — no manual placement group needed.
    print("[engine] Creating AsyncLLMEngine...")
    engine = create_async_engine(
        model=local_model_path,
        enforce_eager=True,
        tensor_parallel_size=INFERENCE_TP_SIZE,
        data_parallel_size=INFERENCE_DP_SIZE,
        enable_expert_parallel=True,
        distributed_executor_backend="ray",
        data_parallel_backend="ray",
        weight_transfer_config=WeightTransferConfig(backend="nccl"),
        load_format="dummy",  # random weights; real ones arrive via NCCL below
        gpu_memory_utilization=0.7,
    )
    print("[engine] AsyncLLMEngine created.")
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)  # greedy decoding
    # Generate with dummy weights — expect gibberish.
    print("[generate] Starting generation with dummy weights...")
    outputs = await generate_batch(engine, prompts, sampling_params)
    print("[generate] Generation complete.")
    print("-" * 60)
    print("BEFORE weight sync (dummy weights):")
    print("-" * 60)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print("-" * 60)
    # --- Weight-transfer setup ---
    print("[transfer] Setting up weight-transfer endpoint...")
    transfer_addr, transfer_port = ray.get(
        fsdp_workers[0].setup_transfer_endpoint.remote()
    )
    print(f"[transfer] Endpoint ready at {transfer_addr}:{transfer_port}")
    # Transfer group = FSDP rank 0 (the sender) + every vLLM worker.
    transfer_world_size = INFERENCE_TP_SIZE * INFERENCE_DP_SIZE + 1
    print(
        f"[transfer] World size: {transfer_world_size} "
        f"(1 trainer + {INFERENCE_TP_SIZE * INFERENCE_DP_SIZE} vLLM workers)"
    )
    print("[transfer] Initializing NCCL groups...")
    # Both sides block in the NCCL rendezvous, so the trainer-side init is
    # launched as a Ray task first and only collected via ray.get AFTER the
    # engine-side init has been awaited — do not reorder these three steps.
    train_handle = fsdp_workers[0].init_weight_transfer_group.remote(
        transfer_world_size
    )
    await engine.init_weight_transfer_engine(
        WeightTransferInitRequest(
            init_info=asdict(
                NCCLWeightTransferInitInfo(
                    master_address=transfer_addr,
                    master_port=transfer_port,
                    rank_offset=1,  # trainer occupies rank 0 of the group
                    world_size=transfer_world_size,
                )
            )
        )
    )
    ray.get(train_handle)
    print("[transfer] NCCL groups initialized.")
    # --- Pause, transfer weights, resume ---
    print("[sync] Pausing generation...")
    # NOTE(review): mode="abort" looks like it drops in-flight requests
    # rather than draining them — confirm against vLLM's pause_generation docs.
    await engine.pause_generation(mode="abort")
    print("[sync] Generation paused.")
    names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
    print(f"[sync] Got metadata for {len(names)} parameters.")
    print("[sync] Broadcasting weights from FSDP → vLLM...")
    # Sender (FSDP ranks) and receiver (engine) must run concurrently:
    # launch the Ray broadcasts first, await the engine's receive, then
    # collect the handles.
    broadcast_handles = [
        w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
    ]
    await engine.update_weights(
        WeightTransferUpdateRequest(
            update_info=asdict(
                NCCLWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    packed=True,  # must match packed=True on the sender side
                )
            )
        )
    )
    ray.get(broadcast_handles)
    print("[sync] Weight broadcast complete.")
    print("[sync] Resuming generation...")
    await engine.resume_generation()
    print("[sync] Generation resumed.")
    # Generate with synced weights — expect sensible output.
    print("[generate] Starting generation with synced weights...")
    outputs_updated = await generate_batch(engine, prompts, sampling_params)
    print("[generate] Generation complete.")
    print("-" * 60)
    print("AFTER weight sync (real weights):")
    print("-" * 60)
    for output in outputs_updated:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print("-" * 60)
if __name__ == "__main__":
    # Script entry point: run the async orchestration in a fresh event loop.
    asyncio.run(main())