vllm.kernels.helion.register

vLLM Helion kernel registration with pre-tuned config selection.

This module leverages Helion's internal config selection infrastructure to use pre-tuned configs instead of runtime autotuning.

How Helion Normally Works

For each kernel invocation, Helion:

  1. Computes a cache key from the input arguments
  2. Looks up the key in its internal compilation cache
  3. On a cache miss, runs autotuning to find the best config
  4. Compiles and caches the kernel with that config
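
A rough pseudocode sketch of that default flow (hypothetical names, not Helion's actual internals):

def call_kernel(kernel, *args):
    key = kernel.key_fn(*args)                  # 1. compute cache key
    if key not in kernel.compilation_cache:     # 2. cache lookup
        config = kernel.autotuner(args).autotune()  # 3. autotune on miss
        kernel.compilation_cache[key] = compile_with_config(kernel, config)  # 4.
    return kernel.compilation_cache[key](*args)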

How We Override It

We override two Helion hooks to use pre-tuned configs:

  1. key: We provide a key function (derived from config_picker) that computes cache keys matching our pre-tuned config keys. This ensures Helion's internal cache uses keys that correspond to configs we've prepared.

  2. autotuner_fn: We provide PresetConfigSearch, which skips autotuning entirely: on a cache miss, Helion calls our autotuner, and it simply returns the author-prepared config for the computed key.

Both hooks use the same config_picker logic to ensure the cache key computed by key matches the config returned by the autotuner.
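
Concretely, both hooks are built from the same key_computer and passed to the Helion decorator together (excerpted from _create_decorated_kernel below, with comments added):

key_computer = self._create_key_computer()
config_selector = self._create_config_selector(key_computer)

extra_kwargs = {
    # On a cache miss, Helion asks this factory for an autotuner; ours just
    # resolves the preset config for the key that key_computer produced.
    "autotuner_fn": lambda _, args: PresetConfigSearch(args, config_selector),
    # Helion computes its internal cache key with this function.
    "key": key_computer,
}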

Key Classes

  • HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
  • ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
  • PresetConfigSearch: Custom autotuner that returns pre-tuned configs
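
These classes fit together roughly as follows (a simplified sketch; my_kernel, my_fake_impl, and pick_config are placeholders, and in practice HelionKernelWrapper is created by the register_kernel decorator documented below):

wrapper = HelionKernelWrapper(raw_kernel_func=my_kernel, op_name="my_op",
                              fake_impl=my_fake_impl, config_picker=pick_config)
configured = wrapper.get_configured_op()  # a ConfiguredHelionKernel
out = configured(x)  # Helion computes the key; on a miss, PresetConfigSearch
                     # supplies the pre-tuned config instead of autotuning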

ConfiguredHelionKernel

A configured Helion kernel bound to a specific platform.

Source code in vllm/kernels/helion/register.py
class ConfiguredHelionKernel:
    """A configured Helion kernel bound to a specific platform."""

    def __init__(
        self,
        op_name: str,
        config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None,
        raw_kernel_func: Callable,
        helion_settings: "helion.Settings | None" = None,
    ):
        self.op_name = op_name
        self.config_picker = config_picker
        self.raw_kernel_func = raw_kernel_func
        self.helion_settings = helion_settings
        self._decorated_kernel = self._create_decorated_kernel()

    def __call__(self, *args, **kwargs):
        return self._decorated_kernel(*args, **kwargs)

    def _create_key_computer(self):
        """
        Create a key computer function derived from the config picker.

        The returned function receives kernel arguments unpacked (*args) to match
        Helion's key signature (called as self._key_fn(*args)).
        """
        if self.config_picker is None:
            raise RuntimeError(
                f"No config picker registered for kernel '{self.op_name}'. "
                f"A config_picker must be provided to register_kernel()."
            )

        # After None check, config_picker is guaranteed to be non-None
        assert self.config_picker is not None

        def key_computer(*args):
            config_keys = list(self.configs.keys())
            # Cast is safe because we checked for None above
            config_picker = cast(
                Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
            )
            selected_key = config_picker(args, config_keys)
            if selected_key:
                return selected_key
            return "default" if "default" in self.configs else None

        return key_computer

    def _create_config_selector(self, key_computer):
        def config_selector(args):
            # args is a tuple; key_computer expects unpacked args
            selected_config_key = key_computer(*args)

            if selected_config_key is None:
                raise ValueError(
                    f"Config picker returned None for kernel '{self.op_name}' "
                    f"with available config keys: {list(self.configs.keys())}"
                )

            if selected_config_key not in self.configs:
                raise ValueError(
                    f"Config picker returned invalid config key "
                    f"'{selected_config_key}' for kernel '{self.op_name}'. "
                    f"Available keys: {list(self.configs.keys())}"
                )

            return self.configs[selected_config_key]

        return config_selector

    def _load_platform_configs(self) -> None:
        from vllm.kernels.helion.config_manager import ConfigManager
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        self.platform = get_canonical_gpu_name()
        config_manager = ConfigManager()
        self.configs = config_manager.get_platform_configs(self.op_name, self.platform)

        if not self.configs:
            raise ValueError(
                f"No configs available for kernel '{self.op_name}' "
                f"on platform '{self.platform}'"
            )

    def _create_decorated_kernel(self) -> Callable[..., Any]:
        self._load_platform_configs()

        key_computer = self._create_key_computer()
        config_selector = self._create_config_selector(key_computer)

        extra_kwargs = {
            "autotuner_fn": lambda _, args: PresetConfigSearch(args, config_selector),
            "key": key_computer,
        }

        logger.debug(
            "Creating decorated kernel %s with custom autotuner on platform %s",
            self.op_name,
            self.platform,
        )
        return create_helion_decorated_kernel(
            self.raw_kernel_func, self.helion_settings, extra_kwargs
        )

_create_key_computer

_create_key_computer()

Create a key computer function derived from the config picker.

The returned function receives kernel arguments unpacked (*args) to match Helion's key signature (called as self._key_fn(*args)).

Source code in vllm/kernels/helion/register.py
def _create_key_computer(self):
    """
    Create a key computer function derived from the config picker.

    The returned function receives kernel arguments unpacked (*args) to match
    Helion's key signature (called as self._key_fn(*args)).
    """
    if self.config_picker is None:
        raise RuntimeError(
            f"No config picker registered for kernel '{self.op_name}'. "
            f"A config_picker must be provided to register_kernel()."
        )

    # After None check, config_picker is guaranteed to be non-None
    assert self.config_picker is not None

    def key_computer(*args):
        config_keys = list(self.configs.keys())
        # Cast is safe because we checked for None above
        config_picker = cast(
            Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
        )
        selected_key = config_picker(args, config_keys)
        if selected_key:
            return selected_key
        return "default" if "default" in self.configs else None

    return key_computer
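
The fallback at the end means a picker may return None and still resolve a config, provided a "default" entry exists. For example (hypothetical configs and inputs):

# Suppose the platform configs are {"default": cfg_a, "hidden_4096": cfg_b}
key_computer = configured._create_key_computer()

key_computer(x_4096)   # picker matches -> "hidden_4096"; Helion caches under it
key_computer(x_other)  # picker returns None -> falls back to "default"
# Without a "default" config, None is returned and the config selector raises.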

HelionKernelWrapper

Wrapper for Helion kernels with pre-tuned config selection and HOP support.

Source code in vllm/kernels/helion/register.py
class HelionKernelWrapper:
    """Wrapper for Helion kernels with pre-tuned config selection and HOP support."""

    def __init__(
        self,
        raw_kernel_func: Callable,
        op_name: str,
        fake_impl: Callable,
        config_picker: Callable[[tuple[Any, ...], list[str]], str | None],
        helion_settings: "helion.Settings | None" = None,
        input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None,
    ):
        # Validate helion_settings doesn't conflict with our custom autotuner
        validate_helion_settings(helion_settings, op_name)

        self.raw_kernel_func = raw_kernel_func
        self.op_name = op_name
        self._fake_impl = fake_impl
        self.helion_settings = helion_settings
        self._config_picker = config_picker
        self._input_generator = input_generator
        self._configured_kernel: ConfiguredHelionKernel | None = None
        # TODO(@gmagogsfm): Remove this disable flag once integrated with vLLM IR,
        # which handles op enablement/disablement.
        self._disabled = False
        self._disabled_reason: str | None = None

        try:
            if not _HOP_AVAILABLE:
                self._get_or_register_custom_op()
            else:
                self.get_configured_op()
        except ValueError as e:
            self._disabled = True
            self._disabled_reason = str(e)
            logger.warning(
                "Helion kernel '%s' is disabled: %s",
                op_name,
                self._disabled_reason,
            )

    def __call__(self, *args, **kwargs):
        if self._disabled:
            raise RuntimeError(
                f"Helion kernel '{self.op_name}' is disabled: {self._disabled_reason}"
            )
        if not _HOP_AVAILABLE:
            op = getattr(torch.ops.vllm_helion, self.op_name)
            return op(*args, **kwargs)
        assert self._configured_kernel is not None, (
            f"Kernel '{self.op_name}' was not initialized. "
            "Please open an issue on GitHub."
        )
        if get_proxy_mode() is not None:
            return self._call_via_hop(args, kwargs)
        return self._configured_kernel(*args, **kwargs)

    def _call_via_hop(
        self,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        kernel = self.get_configured_op()._decorated_kernel
        kernel_idx = helion_kernel_side_table.add_kernel(kernel)

        constant_args, tensor_args = self._partition_args(kernel, args, kwargs)

        all_named = {**constant_args, **tensor_args}
        full_args = tuple(
            all_named.get(n, p.default)
            for n, p in kernel.signature.parameters.items()  # type: ignore[attr-defined]
            if n in all_named or p.default is not p.empty
        )

        with disable_proxy_modes_tracing():
            output_spec = infer_output_spec(kernel, full_args)

        hop_result = helion_kernel_wrapper_mutation(
            kernel_idx=kernel_idx,
            constant_args=constant_args,
            tensor_args=tensor_args,
            output_spec=output_spec,
        )

        tree_spec_str = output_spec.get("tree_spec_str")
        if tree_spec_str is None:
            return None
        tree_spec = pytree.treespec_loads(tree_spec_str)

        hop_iter = iter(hop_result)
        reconstructed = []
        for spec in output_spec["leaf_specs"]:
            is_constant_scalar = spec["type"] == "scalar" and not isinstance(
                spec.get("scalar_value"), torch.SymInt
            )
            if is_constant_scalar:
                reconstructed.append(spec["scalar_value"])
            else:
                reconstructed.append(next(hop_iter))
        return pytree.tree_unflatten(reconstructed, tree_spec)

    @staticmethod
    def _partition_args(
        kernel: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        constant_args: dict[str, Any] = {}
        tensor_args: dict[str, Any] = {}
        params = list(kernel.signature.parameters.keys())
        for i, val in enumerate(args):
            name = params[i]
            if isinstance(val, torch.Tensor):
                tensor_args[name] = val
            else:
                constant_args[name] = val
        for name, val in kwargs.items():
            if isinstance(val, torch.Tensor):
                tensor_args[name] = val
            else:
                constant_args[name] = val
        return constant_args, tensor_args

    def get_inputs(self) -> dict[str, tuple[Any, ...]]:
        if self._input_generator is None:
            raise NotImplementedError(
                f"No input generator registered for kernel '{self.op_name}'. "
                f"Use register_kernel(..., input_generator=...) to register one."
            )
        return self._input_generator()

    def run_autotune(
        self,
        inputs: tuple[Any, ...],
        autotune_effort: str = "quick",
    ) -> Config:
        """Run autotuning for a single input configuration."""
        extra_kwargs = {
            "autotune_effort": autotune_effort,
            "autotune_ignore_errors": True,
        }
        autotune_kernel = create_helion_decorated_kernel(
            self.raw_kernel_func, self.helion_settings, extra_kwargs
        )
        return autotune_kernel.autotune(inputs)

    def get_configured_op(self) -> ConfiguredHelionKernel:
        if self._disabled:
            raise RuntimeError(
                f"Helion kernel '{self.op_name}' is disabled: {self._disabled_reason}"
            )
        if self._configured_kernel is None:
            self._configured_kernel = ConfiguredHelionKernel(
                op_name=self.op_name,
                config_picker=self._config_picker,
                raw_kernel_func=self.raw_kernel_func,
                helion_settings=self.helion_settings,
            )
        return self._configured_kernel

    def _get_or_register_custom_op(self) -> Any:
        if hasattr(torch.ops.vllm_helion, self.op_name):
            return getattr(torch.ops.vllm_helion, self.op_name)

        configured_kernel = self.get_configured_op()

        logger.info("Registering op: vllm_helion::%s", self.op_name)
        direct_register_custom_op(
            op_name=self.op_name,
            op_func=configured_kernel._decorated_kernel,
            mutates_args=None,
            fake_impl=self._fake_impl,
            target_lib=vllm_helion_lib,
        )
        return getattr(torch.ops.vllm_helion, self.op_name)
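
For example, _partition_args splits a call like kernel(x, 0.5, out=buf) purely by tensor-ness (assuming a kernel whose signature is (x, scale, out)):

x, buf = torch.randn(16), torch.empty(16)
constant_args, tensor_args = HelionKernelWrapper._partition_args(
    kernel, (x, 0.5), {"out": buf}
)
# constant_args == {"scale": 0.5}
# tensor_args   == {"x": x, "out": buf}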

run_autotune

run_autotune(
    inputs: tuple[Any, ...], autotune_effort: str = "quick"
) -> Config

Run autotuning for a single input configuration.

Source code in vllm/kernels/helion/register.py
def run_autotune(
    self,
    inputs: tuple[Any, ...],
    autotune_effort: str = "quick",
) -> Config:
    """Run autotuning for a single input configuration."""
    extra_kwargs = {
        "autotune_effort": autotune_effort,
        "autotune_ignore_errors": True,
    }
    autotune_kernel = create_helion_decorated_kernel(
        self.raw_kernel_func, self.helion_settings, extra_kwargs
    )
    return autotune_kernel.autotune(inputs)
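
Together with get_inputs, this supports an offline tuning loop over the registered input shapes (a sketch; "my_op" is a hypothetical op name, and persisting the resulting configs where ConfigManager can find them is assumed to happen separately):

wrapper = _REGISTERED_KERNELS["my_op"]
for key, inputs in wrapper.get_inputs().items():
    config = wrapper.run_autotune(inputs)  # autotune_effort="quick" by default
    print(key, config)  # save this config under `key` for the current platform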

PresetConfigSearch

Bases: BaseAutotuner

Custom autotuner that uses a preset config selector instead of autotuning.

Source code in vllm/kernels/helion/register.py
class PresetConfigSearch(BaseAutotuner):
    """Custom autotuner that uses a preset config selector instead of autotuning."""

    def __init__(
        self,
        args: tuple[Any, ...],
        config_selector: Callable[[tuple[Any, ...]], Config],
    ):
        self.args = args
        self.config_selector = config_selector

    def autotune(self, *, skip_cache: bool = False) -> Config:
        return self.config_selector(self.args)
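
Because autotune() just delegates to the selector, a cache miss costs a lookup rather than a search (a sketch with a trivial selector over a hypothetical configs dict):

selector = lambda args: configs["default"]
autotuner = PresetConfigSearch(args, selector)
autotuner.autotune()  # returns configs["default"]; no benchmarking happens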

register_kernel

register_kernel(
    op_name: str | None = None,
    *,
    config_picker: Callable[
        [tuple[Any, ...], list[str]], str | None
    ],
    fake_impl: Callable | None = None,
    helion_settings: Settings | None = None,
    input_generator: Callable[
        [], dict[str, tuple[Any, ...]]
    ]
    | None = None,
) -> Callable[[Callable], HelionKernelWrapper]

Register a Helion kernel with pre-tuned config selection.

Wraps the kernel function in a HelionKernelWrapper that eagerly builds the configured kernel and (on older PyTorch) registers a custom op.

Parameters:

  • config_picker (Callable[[tuple[Any, ...], list[str]], str | None], required):
    Function with signature (args: tuple, config_keys: list[str]) -> str | None
    that picks the best config key from the available options. Return None to
    fall back to "default".

    Example::

        def pick_config(args, config_keys):
            x = args[0]
            hidden_size = x.shape[-1]
            batch_size = x.shape[0]
            for key in config_keys:
                if key == f"hiddensize_{hidden_size}_batchsize_{batch_size}":
                    return key
            return "default" if "default" in config_keys else None

  • input_generator (Callable[[], dict[str, tuple[Any, ...]]] | None, default None):
    Function that returns dict[str, tuple] where each key is a configuration
    identifier (e.g. "4096", "hidden_4096") and each value is a tuple of
    arguments to pass to the kernel.

    Example::

        def generate_inputs():
            return {
                "4096": (torch.randn(4096, device="cuda"), 0.5),
                "8192": (torch.randn(8192, device="cuda"), 0.5),
            }
Source code in vllm/kernels/helion/register.py
def register_kernel(
    op_name: str | None = None,
    *,
    config_picker: Callable[[tuple[Any, ...], list[str]], str | None],
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
    input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None,
) -> Callable[[Callable], HelionKernelWrapper]:
    """Register a Helion kernel with pre-tuned config selection.

    Wraps the kernel function in a HelionKernelWrapper that eagerly builds
    the configured kernel and (on older PyTorch) registers a custom op.

    Args:
        config_picker: Required. Function with signature
            ``(args: tuple, config_keys: list[str]) -> str | None``
            that picks the best config key from available options.
            Return ``None`` to fall back to ``"default"``.

            Example::

                def pick_config(args, config_keys):
                    x = args[0]
                    hidden_size = x.shape[-1]
                    batch_size = x.shape[0]
                    for key in config_keys:
                        if key == f"hiddensize_{hidden_size}_batchsize_{batch_size}":
                            return key
                    return "default" if "default" in config_keys else None

        input_generator: Optional. Function that returns
            ``dict[str, tuple]`` where each key is a configuration
            identifier (e.g. ``"4096"``, ``"hidden_4096"``) and each
            value is a tuple of arguments to pass to the kernel.

            Example::

                def generate_inputs():
                    return {
                        "4096": (torch.randn(4096, device="cuda"), 0.5),
                        "8192": (torch.randn(8192, device="cuda"), 0.5),
                    }
    """

    def decorator(kernel_func: Callable) -> HelionKernelWrapper:
        final_op_name = op_name if op_name else kernel_func.__name__

        if final_op_name in _REGISTERED_KERNELS:
            raise ValueError(
                f"Helion kernel '{final_op_name}' is already registered. "
                f"Use a different op_name or check for duplicate registrations."
            )

        final_fake_impl = fake_impl
        if final_fake_impl is None:
            final_fake_impl = infer_fake_impl(kernel_func, helion_settings)
            logger.debug(
                "Auto-generated fake_impl for Helion kernel '%s'",
                kernel_func.__name__,
            )

        kernel_wrapper = HelionKernelWrapper(
            raw_kernel_func=kernel_func,
            op_name=final_op_name,
            fake_impl=final_fake_impl,
            config_picker=config_picker,
            helion_settings=helion_settings,
            input_generator=input_generator,
        )

        _REGISTERED_KERNELS[final_op_name] = kernel_wrapper

        logger.info(
            "Registered Helion kernel '%s' as HelionKernelWrapper",
            kernel_func.__name__,
        )

        return kernel_wrapper

    return decorator
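
Putting it all together, a registration might look like this (a sketch: the op name, picker keys, and shapes are illustrative, and the Helion kernel body is elided):

import torch
from vllm.kernels.helion.register import register_kernel

def pick_config(args, config_keys):
    key = f"hidden_{args[0].shape[-1]}"
    return key if key in config_keys else None  # None -> "default" fallback

def generate_inputs():
    return {
        "hidden_4096": (torch.randn(8, 4096, device="cuda"),),
        "hidden_8192": (torch.randn(8, 8192, device="cuda"),),
    }

@register_kernel(config_picker=pick_config, input_generator=generate_inputs)
def my_op(x: torch.Tensor) -> torch.Tensor:
    ...  # Helion kernel body elided

y = my_op(torch.randn(8, 4096, device="cuda"))  # runs with the pre-tuned config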