class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
# Default number of threads used when multithreaded weight loading is enabled.
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
revision: str | None
"""The optional model revision."""
subfolder: str | None = None
"""The subfolder inside the model repo."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
allow_patterns_overrides: list[str] | None = None
"""If defined, weights will load exclusively using these patterns."""
# ``time.perf_counter()`` reading stored by ``_get_weights_iterator`` the first
# time it runs while this value is still 0.0 (i.e. not yet recorded).
counter_before_loading_weights: float = 0.0
# ``time.perf_counter()`` reading stored by ``load_weights`` right after
# ``model.load_weights`` returns.
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig):
    """Initialize the loader and validate its extra configuration.

    Args:
        load_config: Load configuration. Its ``model_loader_extra_config``
            may only contain the keys ``enable_multithread_load`` and
            ``num_threads``.

    Raises:
        ValueError: If ``model_loader_extra_config`` contains any other key.
    """
    super().__init__(load_config)

    # Stays ``None`` unless ``_init_ep_weight_filter`` computes a set.
    self.local_expert_ids: set[int] | None = None

    provided = load_config.model_loader_extra_config
    supported = {"enable_multithread_load", "num_threads"}
    unexpected = set(provided.keys()) - supported
    if not unexpected:
        return

    raise ValueError(
        f"Unexpected extra config keys for load format "
        f"{load_config.load_format}: {unexpected}"
    )
def _prepare_weights(
    self,
    model_name_or_path: str,
    subfolder: str | None,
    revision: str | None,
    fall_back_to_pt: bool,
    allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]:
    """Resolve the weight files to load, downloading them if necessary.

    Args:
        model_name_or_path: Model ID or local directory. An existing
            directory is used as-is; anything else is downloaded.
        subfolder: Optional subfolder, joined onto the model folder, in
            which the weight files are searched.
        revision: Optional model revision.
        fall_back_to_pt: Whether ``*.pt`` is appended to the patterns
            derived from the load format.
        allow_patterns_overrides: If not ``None``, replaces the patterns
            derived from the load format entirely.

    Returns:
        A ``(folder, weight_files, use_safetensors)`` tuple.

    Raises:
        ValueError: If the configured load format is not recognized.
        RuntimeError: If no weight files remain after filtering.
    """
    modelscope_path = maybe_download_from_modelscope(model_name_or_path, revision)
    if modelscope_path:
        model_name_or_path = modelscope_path

    is_local = os.path.isdir(model_name_or_path)
    load_format = self.load_config.load_format

    # With "auto", check first whether files in the official Mistral format
    # are present, so that Mistral models load in that format by default.
    if load_format == "auto":
        consolidated_files = list_filtered_repo_files(
            model_name_or_path=model_name_or_path,
            allow_patterns=["consolidated*.safetensors"],
            revision=revision,
        )
        load_format = "mistral" if len(consolidated_files) > 0 else "hf"

    use_safetensors = False
    index_file = SAFE_WEIGHTS_INDEX_NAME
    if load_format == "hf":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format in ("safetensors", "fastsafetensors", "instanttensor"):
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "mistral":
        use_safetensors = True
        allow_patterns = ["consolidated*.safetensors"]
        index_file = "consolidated.safetensors.index.json"
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    # Some quantized models use .pt files for storing the weights.
    if fall_back_to_pt:
        allow_patterns.append("*.pt")

    if allow_patterns_overrides is not None:
        allow_patterns = allow_patterns_overrides

    if is_local:
        hf_folder = model_name_or_path
    else:
        hf_folder = download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            subfolder=subfolder,
            ignore_patterns=self.load_config.ignore_patterns,
        )
    if subfolder is not None:
        hf_folder = os.path.join(hf_folder, subfolder)

    # Patterns are tried in order; the first one that matches any file wins.
    # NOTE(review): for the safetensors-family and mistral formats
    # ``use_safetensors`` was already set above, so it stays True even when
    # only the ``*.pt`` fallback pattern matched -- confirm this is intended.
    hf_weights_files: list[str] = []
    for pattern in allow_patterns:
        hf_weights_files = glob.glob(os.path.join(hf_folder, pattern))
        if not hf_weights_files:
            continue
        if pattern == "*.safetensors":
            use_safetensors = True
        break

    if use_safetensors:
        # Models like Mistral-7B-Instruct-v0.3 ship both sharded safetensors
        # files and a consolidated safetensors file; using both breaks.
        # Fetch the index file (for remote models) and drop any files that
        # are not listed in it.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                cache_dir=self.load_config.download_dir,
                subfolder=subfolder,
                revision=revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file
        )
    else:
        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`"
        )

    return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
    self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Build the weights iterator for ``source`` based on the load format.

    Args:
        source: Describes where the weights live and which prefix is
            prepended to every weight name.

    Returns:
        A generator of ``(source.prefix + name, tensor)`` pairs.
    """
    extra_config = self.load_config.model_loader_extra_config
    hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
        source.model_or_path,
        source.subfolder,
        source.revision,
        source.fall_back_to_pt,
        source.allow_patterns_overrides,
    )

    load_format = self.load_config.load_format
    use_tqdm = self.load_config.use_tqdm_on_load

    if load_format == "npcache":
        # Currently np_cache only supports *.bin checkpoints.
        assert use_safetensors is False
        weights_iterator = np_cache_weights_iterator(
            source.model_or_path,
            self.load_config.download_dir,
            hf_folder,
            hf_weights_files,
            use_tqdm,
        )
    elif use_safetensors and load_format == "fastsafetensors":
        weights_iterator = fastsafetensors_weights_iterator(
            hf_weights_files,
            use_tqdm,
        )
    elif use_safetensors and load_format == "instanttensor":
        weights_iterator = instanttensor_weights_iterator(
            hf_weights_files,
            use_tqdm,
        )
    elif use_safetensors and extra_config.get("enable_multithread_load"):
        weights_iterator = multi_thread_safetensors_weights_iterator(
            hf_weights_files,
            use_tqdm,
            max_workers=extra_config.get("num_threads", self.DEFAULT_NUM_THREADS),
        )
    elif use_safetensors:
        weights_iterator = safetensors_weights_iterator(
            hf_weights_files,
            use_tqdm,
            self.load_config.safetensors_load_strategy,
            local_expert_ids=self.local_expert_ids,
        )
    elif extra_config.get("enable_multithread_load"):
        weights_iterator = multi_thread_pt_weights_iterator(
            hf_weights_files,
            use_tqdm,
            self.load_config.pt_load_map_location,
            max_workers=extra_config.get("num_threads", self.DEFAULT_NUM_THREADS),
        )
    else:
        weights_iterator = pt_weights_iterator(
            hf_weights_files,
            use_tqdm,
            self.load_config.pt_load_map_location,
        )

    # Record the start timestamp only once, so that later sources do not
    # overwrite the value stored for the first one.
    if self.counter_before_loading_weights == 0.0:
        self.counter_before_loading_weights = time.perf_counter()

    # Apply the prefix.
    return ((source.prefix + key, value) for key, value in weights_iterator)
def get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield the weights of the primary source, then of any secondary ones.

    Args:
        model_config: Supplies the primary model ID/path and revision.
        model: May define ``fall_back_to_pt_during_load``,
            ``allow_patterns_overrides`` and ``secondary_weights``; each is
            read with ``getattr`` and a default.

    Yields:
        ``(name, tensor)`` pairs from ``_get_weights_iterator``.
    """
    primary = DefaultModelLoader.Source(
        model_or_path=model_config.model,
        revision=model_config.revision,
        prefix="",
        fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
        allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
    )
    yield from self._get_weights_iterator(primary)

    # Looked up only after the primary weights have been fully consumed.
    extra_sources = getattr(model, "secondary_weights", ())
    for extra in cast(Iterable[DefaultModelLoader.Source], extra_sources):
        yield from self._get_weights_iterator(extra)
def download_model(self, model_config: ModelConfig) -> None:
    """Run ``_prepare_weights`` for the model and discard its result.

    Args:
        model_config: Supplies the model ID/path and revision.
    """
    model_id = model_config.model
    model_revision = model_config.revision
    self._prepare_weights(
        model_name_or_path=model_id,
        subfolder=None,
        revision=model_revision,
        fall_back_to_pt=True,
        allow_patterns_overrides=None,
    )
def _init_ep_weight_filter(self, model_config: ModelConfig) -> None:
    """Compute the local expert ids used for EP weight filtering.

    With expert parallelism each rank only needs a subset of the expert
    weights. Computing that subset upfront lets non-local expert tensors be
    skipped *before* they are read from disk.

    The result is stored in ``self.local_expert_ids``; the attribute is left
    untouched whenever one of the early-return conditions below applies.
    """
    from vllm.config import get_current_vllm_config

    parallel_config = get_current_vllm_config().parallel_config

    # The filter only applies to MoE models with expert parallelism and the
    # EP weight filter both enabled.
    if not model_config.is_moe:
        return
    if not parallel_config.enable_expert_parallel:
        return
    if not parallel_config.enable_ep_weight_filter:
        return

    # When EPLB is enabled, redundant physical expert slots may map to
    # logical experts that belong to other ranks in the default partition.
    # The weight loader needs to see ALL logical expert weights so it can
    # populate these redundant slots. Skip the filter entirely.
    if parallel_config.enable_eplb:
        return

    num_experts = model_config.get_num_experts()
    if num_experts <= 0:
        return

    from vllm.distributed import (
        get_dp_group,
        get_pcp_group,
        get_tensor_model_parallel_rank,
    )

    dp_size = parallel_config.data_parallel_size
    tp_size = parallel_config.tensor_parallel_size
    pcp_size = parallel_config.prefill_context_parallel_size

    # A dimension of size 1 contributes rank 0 without querying its group.
    dp_rank = 0 if dp_size <= 1 else get_dp_group().rank_in_group
    tp_rank = 0 if tp_size <= 1 else get_tensor_model_parallel_rank()
    pcp_rank = 0 if pcp_size <= 1 else get_pcp_group().rank_in_group

    # EP size/rank computation mirrors FusedMoEParallelConfig.make():
    #   ep_size = dp_size * pcp_size * tp_size (flattened)
    #   ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
    ep_size = dp_size * pcp_size * tp_size
    ep_rank = (dp_rank * pcp_size + pcp_rank) * tp_size + tp_rank

    self.local_expert_ids = compute_local_expert_ids(
        num_experts,
        ep_size,
        ep_rank,
        placement=parallel_config.expert_placement_strategy,
    )
    if self.local_expert_ids is None:
        return

    logger.info_once(
        "EP weight filter: ep_size=%d, ep_rank=%d, loading %d/%d experts",
        ep_size,
        ep_rank,
        len(self.local_expert_ids),
        num_experts,
    )
@instrument(span_name="Load weights")
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Load all checkpoint weights into ``model`` and log how long it took.

    Args:
        model: Module whose ``load_weights`` method consumes the
            ``(name, tensor)`` pairs produced by ``get_all_weights``.
        model_config: Configuration of the model being loaded.

    Raises:
        ValueError: If the model is not quantized, ``model.load_weights``
            returned the set of weights it loaded, and some of the model's
            named parameters are missing from that set.
    """
    if model_config.quantization == "torchao":
        quant_config = get_quant_config(model_config, self.load_config)
        if (
            hasattr(quant_config, "is_checkpoint_torchao_serialized")
            and quant_config.is_checkpoint_torchao_serialized
            and torchao_version_at_least("0.15.0")
        ):
            self.load_config.safetensors_load_strategy = "torchao"

    self._init_ep_weight_filter(model_config)

    # ``_get_weights_iterator`` only stores a start timestamp while the
    # counter is still 0.0. Clear whatever a previous call on this loader
    # left behind, so the duration logged below covers this load only
    # instead of being measured from the start of the very first load.
    self.counter_before_loading_weights = 0.0

    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

    self.counter_after_loading_weights = time.perf_counter()
    logger.info_once(
        "Loading weights took %.2f seconds",
        self.counter_after_loading_weights - self.counter_before_loading_weights,
        scope="local",
    )

    # We only enable strict check for non-quantized models
    # that have loaded weights tracking currently.
    if model_config.quantization is None and loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError(
                "Following weights were not initialized from "
                f"checkpoint: {weights_not_loaded}"
            )