class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
    """Identifier of the mamba-style layer kind implemented by this module.

    Always the constant string ``"gdn_attention"`` (gated delta net).
    """
    layer_kind = "gdn_attention"
    return layer_kind
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
    """Return the dtypes of this layer's two cached states.

    The computation is delegated to
    ``MambaStateDtypeCalculator.gated_delta_net_state_dtype``, fed with the
    model dtype plus the configured mamba cache and SSM-cache dtypes.
    ``_warmup_prefill_kernels`` uses the second element as the dtype of the
    recurrent state it allocates.
    """
    model_dtype = self.model_config.dtype
    cache_cfg = self.cache_config
    return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
        model_dtype,
        cache_cfg.mamba_cache_dtype,
        cache_cfg.mamba_ssm_cache_dtype,
    )
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the shapes of this layer's two cached states.

    The computation is delegated to
    ``MambaStateShapeCalculator.gated_delta_net_state_shape``. Its inputs
    are the TP world size, the key/value head counts and head dims, the conv
    kernel width, and the number of speculative tokens (0 when no
    speculative config was given).
    """
    shape_args = (
        self.tp_size,
        self.num_k_heads,
        self.num_v_heads,
        self.head_k_dim,
        self.head_v_dim,
        self.conv_kernel_size,
        self.num_spec,
    )
    return MambaStateShapeCalculator.gated_delta_net_state_shape(*shape_args)
def __init__(
    self,
    config: Qwen3NextConfig,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    speculative_config: SpeculativeConfig | None = None,
    prefix: str = "",
) -> None:
    """Build the gated delta net (linear attention) layer.

    Args:
        config: Model config providing the hidden size, linear-attention
            head counts and head dims, conv kernel width, activation name
            and RMSNorm epsilon.
        model_config: Stored; its ``dtype`` is read by ``get_state_dtype``.
        cache_config: Stored; its mamba cache dtypes are read by
            ``get_state_dtype``.
        quant_config: Forwarded to the input and output projections.
        speculative_config: When set, its ``num_speculative_tokens`` becomes
            ``self.num_spec``; otherwise ``self.num_spec`` is 0.
        prefix: Fully-qualified layer name. Used for sub-module prefixes and
            as this layer's key in the compilation static forward context.

    Raises:
        ValueError: If a layer with the same ``prefix`` is already
            registered in the static forward context.
    """
    super().__init__()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.hidden_size = config.hidden_size
    self.num_v_heads = config.linear_num_value_heads
    self.num_k_heads = config.linear_num_key_heads
    self.head_k_dim = config.linear_key_head_dim
    self.head_v_dim = config.linear_value_head_dim
    self.key_dim = self.head_k_dim * self.num_k_heads
    self.value_dim = self.head_v_dim * self.num_v_heads
    self.conv_kernel_size = config.linear_conv_kernel_dim
    self.layer_idx = extract_layer_index(prefix)
    self.activation = config.hidden_act
    self.act = ACT2FN[config.hidden_act]
    self.layer_norm_epsilon = config.rms_norm_eps
    self.prefix = prefix
    # Auxiliary stream and event pair handed to maybe_execute_in_parallel in
    # _forward_in_proj; events are only created on CUDA-alike platforms.
    self.aux_stream = aux_stream()
    self.events = (
        [torch.cuda.Event(), torch.cuda.Event()]
        if current_platform.is_cuda_alike()
        else [None, None]
    )
    self.config = config
    self.model_config = model_config
    self.cache_config = cache_config
    self.quant_config = quant_config
    self.speculative_config = speculative_config
    self.num_spec = (
        self.speculative_config.num_speculative_tokens
        if self.speculative_config
        else 0
    )

    # QKV: the causal conv runs over the concatenated q|k|v channels.
    self.conv_dim = self.key_dim * 2 + self.value_dim
    self.conv1d = ColumnParallelLinear(
        input_size=self.conv_kernel_size,
        output_size=self.conv_dim,
        bias=False,
        prefix=f"{prefix}.conv1d",
    )
    # Insert a singleton dim at position 1; _forward_core views the weight
    # back to 2-D using sizes 0 and 2.
    self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

    # projection of the input hidden states
    # Qwen3-Next and Qwen3.5 have different qkvz checkpoint layouts, so the
    # projection is created through an overridable factory method.
    self.in_proj_qkvz = self.create_qkvz_proj(
        hidden_size=self.hidden_size,
        key_dim=self.key_dim,
        value_dim=self.value_dim,
        quant_config=quant_config,
        prefix=f"{prefix}.in_proj_qkvz",
    )
    # ba_proj doesn't support blockwise fp8 quantization.
    # NOTE(review): quant_config is still forwarded below; confirm the
    # factory / linear layer handles that restriction.
    # Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
    # layouts, so we use a factory method to create the projection.
    self.in_proj_ba = self.create_ba_proj(
        hidden_size=self.hidden_size,
        num_v_heads=self.num_v_heads,
        quant_config=quant_config,
        prefix=f"{prefix}.in_proj_ba",
    )

    # Replace the conv weight loader so the fused q|k|v conv channels are
    # sharded per segment (query, key, value) across TP ranks.
    query_key_settings = (self.key_dim, 0, False)
    value_settings = (self.value_dim, 0, False)
    delattr(self.conv1d.weight, "weight_loader")
    set_weight_attrs(
        self.conv1d.weight,
        {
            "weight_loader": mamba_v2_sharded_weight_loader(
                [
                    query_key_settings,
                    query_key_settings,
                    value_settings,
                ],
                self.tp_size,
                self.tp_rank,
            )
        },
    )

    # Per-value-head dt bias and A_log, sharded along dim 0 across TP ranks.
    # Both use divide() so a head count that does not split evenly across
    # ranks is rejected consistently.
    self.dt_bias = nn.Parameter(
        torch.ones(divide(self.num_v_heads, self.tp_size)),
    )
    self.A_log = nn.Parameter(
        torch.empty(
            divide(self.num_v_heads, self.tp_size),
        )
    )
    set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
    set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

    # Gated RMSNorm applied per value head, with z as the gate (see forward).
    self.norm = RMSNormGated(
        self.head_v_dim,
        eps=self.layer_norm_epsilon,
        group_size=None,
        norm_before_gate=True,
        device=current_platform.current_device(),
    )
    self.out_proj = RowParallelLinear(
        self.value_dim,
        self.hidden_size,
        bias=False,
        input_is_parallel=True,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )
    self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
    self.enable_packed_recurrent_decode = (
        envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
    )

    # Register this layer under its prefix so the custom ops can look it up.
    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
def create_qkvz_proj(
    self,
    hidden_size: int,
    key_dim: int,
    value_dim: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Create the fused q/k/v/z input projection.

    The projection has one output shard of width
    ``2 * key_dim + 2 * value_dim`` (q, k, v and z together). It is kept as a
    single shard because ``fix_query_key_value_ordering`` views each rank's
    slice as per-key-head groups of ``[q, k, v, z]``, which presumably
    requires TP sharding not to split the fused output by segment. Subclasses
    whose checkpoints use a different qkvz layout override this factory.
    """
    fused_width = 2 * key_dim + 2 * value_dim
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[fused_width],
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Create the fused b/a gate projection (Qwen3-Next layout).

    Qwen3-Next checkpoints store ``in_proj_ba`` as one fused weight,
    interleaved per key-head group as ``[b_g0, a_g0, b_g1, a_g1, ...]``.
    A single output shard of width ``2 * num_v_heads`` is used so that
    ColumnParallel sharding keeps this interleaving intact on every TP rank.
    Qwen3.5 overrides this factory with two ``num_v_heads`` shards, because
    its checkpoint has separate ``in_proj_b`` and ``in_proj_a`` weights.
    """
    fused_width = 2 * num_v_heads
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[fused_width],
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
def fix_query_key_value_ordering(
    self,
    mixed_qkvz: torch.Tensor,
    mixed_ba: torch.Tensor,
):
    """Split this rank's fused projection outputs into per-head tensors.

    Both inputs use a key-head-grouped (interleaved) layout. For every local
    key-head group, ``mixed_qkvz`` holds ``[q (hk), k (hk), v (r*hv),
    z (r*hv)]`` and ``mixed_ba`` holds ``[b (r), a (r)]``, where
    ``r = num_v_heads // num_k_heads`` (assumed to divide evenly).

    Args:
        mixed_qkvz: Output of ``in_proj_qkvz``, shape ``[T, local_qkvz]``.
        mixed_ba: Output of ``in_proj_ba``, shape ``[T, local_ba]``.

    Returns:
        ``(query, key, value, z, b, a)`` with shapes
        ``query, key: [T, nk_local, hk]``,
        ``value, z: [T, nv_local, hv]`` and
        ``b, a: [T, nv_local]``.
    """
    local_k_heads = self.num_k_heads // self.tp_size
    local_v_heads = self.num_v_heads // self.tp_size
    # GQA ratio: number of value heads that share each key head.
    v_per_k = self.num_v_heads // self.num_k_heads
    grouped_v_dim = v_per_k * self.head_v_dim

    # View each input using its OWN leading dims, so the ba view no longer
    # depends on mixed_qkvz having the same leading shape.
    # [T, nk_local * (2*hk + 2*r*hv)] -> [T, nk_local, 2*hk + 2*r*hv]
    mixed_qkvz = mixed_qkvz.view(
        *mixed_qkvz.size()[:-1],
        local_k_heads,
        2 * self.head_k_dim + 2 * grouped_v_dim,
    )
    # [T, nk_local * 2r] -> [T, nk_local, 2r]
    mixed_ba = mixed_ba.view(
        *mixed_ba.size()[:-1],
        local_k_heads,
        2 * v_per_k,
    )

    # Split within each key-head group.
    query, key, value, z = torch.split(
        mixed_qkvz,
        [self.head_k_dim, self.head_k_dim, grouped_v_dim, grouped_v_dim],
        dim=2,
    )
    b, a = torch.split(mixed_ba, [v_per_k, v_per_k], dim=2)

    # [T, nk_local, r * hv] -> [T, nv_local, hv]
    value = value.reshape(value.size(0), -1, self.head_v_dim)
    z = z.reshape(z.size(0), -1, self.head_v_dim)
    # [T, nk_local, r] -> [T, nv_local]
    b = b.reshape(b.size(0), local_v_heads)
    a = a.reshape(a.size(0), local_v_heads)
    return query, key, value, z, b, a
def rearrange_mixed_qkv(self, mixed_qkv):
    """Split a packed ``[L, q|k|v]`` tensor into ``[1, L, heads, dim]`` parts.

    Returns ``(None, None, None)`` when ``mixed_qkv`` is ``None``, i.e. when
    the spec or non-spec partition it stands for is empty. Otherwise the
    last dim is split into this rank's query, key and value widths. Each
    piece is reshaped to ``[1, L, heads, head_dim]`` and returned
    contiguous.
    """
    if mixed_qkv is None:
        return None, None, None

    k_width = self.key_dim // self.tp_size
    v_width = self.value_dim // self.tp_size
    q_flat, k_flat, v_flat = torch.split(
        mixed_qkv, [k_width, k_width, v_width], dim=-1
    )

    pattern = "l (h d) -> 1 l h d"
    query = rearrange(q_flat, pattern, d=self.head_k_dim).contiguous()
    key = rearrange(k_flat, pattern, d=self.head_k_dim).contiguous()
    value = rearrange(v_flat, pattern, d=self.head_v_dim).contiguous()
    return query, key, value
def forward(
    self,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
):
    """Run the layer and write its result into ``output[:num_tokens]``.

    Stages:
      1. Input projection via the ``gdn_in_proj`` custom op, then splitting
         the fused outputs into q/k/v/z/b/a.
      2. Core attention via the ``gdn_attention_core`` custom op, which fills
         a zero-initialised ``[T, nv_local, hv]`` buffer.
      3. Gated RMSNorm (gate ``z``) followed by ``out_proj``.
    """
    num_tokens = hidden_states.size(0)
    local_v_heads = self.num_v_heads // self.tp_size

    # ---- 1. Input projection -------------------------------------------
    qkvz_out, ba_out = torch.ops.vllm.gdn_in_proj(
        hidden_states,
        self.in_proj_qkvz.weight.shape[0],
        self.in_proj_ba.weight.shape[0],
        self.prefix,
    )
    query, key, value, z, b, a = self.fix_query_key_value_ordering(
        qkvz_out, ba_out
    )
    # Flatten heads and concatenate into a single [T, q|k|v] tensor.
    flat_parts = [rearrange(t, "l p d -> l (p d)") for t in (query, key, value)]
    mixed_qkv = torch.cat(flat_parts, dim=-1)

    # ---- 2. Core attention (custom op) ---------------------------------
    # Must be zero-initialised rather than torch.empty like other attention
    # backends; see https://github.com/vllm-project/vllm/pull/28182
    core_attn_out = torch.zeros(
        (num_tokens, local_v_heads, self.head_v_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    torch.ops.vllm.gdn_attention_core(
        mixed_qkv,
        b,
        a,
        core_attn_out,
        self.prefix,
    )

    # ---- 3. Gated norm + output projection -----------------------------
    z_shape = z.shape
    # The norm works on 2-D [rows, head_v_dim] inputs.
    normed = self.norm(
        core_attn_out.reshape(-1, core_attn_out.shape[-1]),
        z.reshape(-1, z.shape[-1]),
    )
    normed = rearrange(normed.reshape(z_shape), "... h d -> ... (h d)")
    projected, _ = self.out_proj(normed)
    output[:num_tokens] = projected
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
    """Warm up GDN prefill kernels during V1 profiling.

    During V1 profile runs, ``_forward_core`` returns early because
    ``attn_metadata`` is ``None``. As a result the autotuned kernels used by
    ``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
    ``chunk_scaled_dot_kkt``) are never invoked. After profiling, vLLM
    allocates KV cache using most of the remaining GPU memory. When the
    first real inference triggers the autotuner, it OOMs because there is
    not enough memory left for benchmarking.

    This method runs minimal forward passes through
    ``chunk_gated_delta_rule`` with small dummy tensors. This forces
    autotuning while GPU memory is still plentiful. The autotuner results
    are cached globally, so only the first layer incurs actual benchmarking
    cost.

    ``g`` and ``beta`` are produced with ``fused_gdn_gating``, the same
    function the real prefill path in ``_forward_core`` uses. This keeps
    their dtypes identical to inference. Triton includes tensor-argument
    dtypes in its autotune cache key, so gates built by hand in the
    activation dtype could miss that cache and re-trigger autotuning on the
    first real prefill.

    Most kernels use a fixed ``BT = chunk_size`` (64), but
    ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence length:
    ``min(64, max(16, next_power_of_2(T)))``. Since ``BT`` is part of its
    autotune key, warmup passes run with T = 16, 32 and 64 to cover all
    possible ``BT`` values.

    The decode path uses ``fused_sigmoid_gating_delta_rule_update``, which
    has fixed kernel parameters (no autotuning), so only the prefill
    (chunked) path needs warming up.

    Failures are logged as warnings and never raised; the warmup runs at
    most once per layer instance.
    """
    if hasattr(self, "_prefill_kernels_warmed_up"):
        return
    self._prefill_kernels_warmed_up = True

    device = mixed_qkv.device
    dtype = mixed_qkv.dtype
    num_k_heads = self.num_k_heads // self.tp_size
    num_v_heads = self.num_v_heads // self.tp_size
    _, state_dtype = self.get_state_dtype()

    # Run warmup for each possible BT value of chunk_fwd_kernel_o:
    # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
    # Other kernels always use BT=chunk_size(64), so their autotune
    # cache is populated on the first pass and reused thereafter.
    for T in (16, 32, 64):
        q = torch.randn(
            1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
        )
        k = torch.randn(
            1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
        )
        v = torch.randn(
            1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
        )
        # Raw gate inputs shaped like the `a`/`b` that _forward_core passes
        # to fused_gdn_gating: [tokens, local value heads], activation dtype.
        a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
        b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
        state = torch.zeros(
            1,
            num_v_heads,
            self.head_v_dim,
            self.head_k_dim,
            device=device,
            dtype=state_dtype,
        )
        cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)
        # Pre-bind so the `finally` cleanup is valid even if gating fails.
        g = beta = None
        try:
            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
            self.chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=state,
                output_final_state=False,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=True,
            )
        except Exception:
            logger.warning(
                "GDN prefill kernel warmup (T=%d) failed for "
                "layer %s. First inference may OOM due to "
                "autotuner.",
                T,
                self.prefix,
                exc_info=True,
            )
        else:
            logger.debug(
                "GDN prefill kernel warmup (T=%d) completed for layer %s",
                T,
                self.prefix,
            )
        finally:
            # Drop references before emptying the cache so the memory can
            # actually be released.
            del q, k, v, a, b, g, beta, state, cu_seqlens
            torch.accelerator.empty_cache()
def _forward_in_proj(
    self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the qkvz and ba input projections of ``hidden_states``.

    Both projections are handed to ``maybe_execute_in_parallel`` together
    with this layer's two events and auxiliary stream. Whether they actually
    overlap is decided by that helper. Only element 0 (the output tensor) of
    each linear layer's result is kept.
    """

    def run_qkvz() -> torch.Tensor:
        return self.in_proj_qkvz(hidden_states)[0]

    def run_ba() -> torch.Tensor:
        return self.in_proj_ba(hidden_states)[0]

    first_event, second_event = self.events[0], self.events[1]
    qkvz_out, ba_out = maybe_execute_in_parallel(
        run_qkvz,
        run_ba,
        first_event,
        second_event,
        self.aux_stream,
    )
    return qkvz_out, ba_out
def _forward_core(
    self,
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
):
    """Core gated-delta-rule attention for one batch, written in place.

    Tokens are partitioned into a speculative ("spec", multi-query) part and
    a non-spec part. Each part goes through a causal conv1d stage and then a
    recurrent / chunked delta-rule stage. The results are merged back into
    token order in ``core_attn_out[:num_actual_tokens]``.

    Args:
        mixed_qkv: Per-token concatenated q|k|v before the conv, presumably
            ``[T, local_q + local_k + local_v]``.
        b: Beta gate inputs, ``[T, nv_local]`` (from
            ``fix_query_key_value_ordering``).
        a: Decay gate inputs, ``[T, nv_local]``.
        core_attn_out: Output buffer; rows ``[:num_actual_tokens]`` are
            overwritten.

    Side effects:
        Updates the conv state and SSM state held in ``self.kv_cache[0]``
        for the sequences in this batch.

    Returns ``None``. The profile run (no attention metadata) only warms up
    the prefill kernels. Pure non-spec decode batches may be routed to
    ``_forward_core_decode_non_spec`` when the packed-decode env flag is on.
    """
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        # V1 profile run — warm up prefill kernels so that
        # autotuning completes before KV cache allocation.
        self._warmup_prefill_kernels(mixed_qkv)
        return
    # Metadata is a per-layer dict keyed by layer prefix.
    assert isinstance(attn_metadata, dict)
    attn_metadata = attn_metadata[self.prefix]
    assert isinstance(attn_metadata, GDNAttentionMetadata)
    # Fast path: only non-spec decodes in this batch (no spec, no prefill).
    if (
        self.enable_packed_recurrent_decode
        and attn_metadata.spec_sequence_masks is None
        and attn_metadata.num_prefills == 0
        and attn_metadata.num_decodes > 0
    ):
        return self._forward_core_decode_non_spec(
            mixed_qkv=mixed_qkv,
            b=b,
            a=a,
            core_attn_out=core_attn_out,
            attn_metadata=attn_metadata,
        )
    has_initial_state = attn_metadata.has_initial_state
    spec_query_start_loc = attn_metadata.spec_query_start_loc
    non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
    spec_sequence_masks = attn_metadata.spec_sequence_masks
    spec_token_indx = attn_metadata.spec_token_indx
    non_spec_token_indx = attn_metadata.non_spec_token_indx
    spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
    non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
    # kv_cache[0] holds (conv state, ssm state); the conv state is used with
    # its last two dims swapped.
    self_kv_cache = self.kv_cache[0]
    conv_state = self_kv_cache[0].transpose(-1, -2)
    ssm_state = self_kv_cache[1]
    num_actual_tokens = attn_metadata.num_actual_tokens
    num_accepted_tokens = attn_metadata.num_accepted_tokens
    # Drop rows beyond the real token count (inputs may be padded — TODO
    # confirm where the padding comes from).
    mixed_qkv = mixed_qkv[:num_actual_tokens]
    b = b[:num_actual_tokens]
    a = a[:num_actual_tokens]
    # 1. Convolution sequence transformation
    # The conv weight had a singleton dim inserted at index 1 in __init__;
    # view it back to 2-D for the conv kernels.
    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )
    # Partition tokens into spec and non-spec sets (whole batch is spec when
    # there are no prefills and no non-spec decodes).
    if spec_sequence_masks is not None:
        if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
            mixed_qkv_spec = mixed_qkv
            mixed_qkv_non_spec = None
        else:
            mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
            mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
    else:
        mixed_qkv_spec = None
        mixed_qkv_non_spec = mixed_qkv
    # 1.1: Process the multi-query part
    if spec_sequence_masks is not None:
        mixed_qkv_spec = causal_conv1d_update(
            mixed_qkv_spec,
            conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=spec_state_indices_tensor[:, 0][
                : attn_metadata.num_spec_decodes
            ],
            num_accepted_tokens=num_accepted_tokens,
            query_start_loc=spec_query_start_loc,
            max_query_len=spec_state_indices_tensor.size(-1),
            validate_data=False,
        )
    # 1.2: Process the remaining part
    if attn_metadata.num_prefills > 0:
        # causal_conv1d_fn takes channels-first input, hence the transposes.
        mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
        # - "cache_indices" updates the conv_state cache in positions
        #   pointed to by "state_indices_tensor"
        mixed_qkv_non_spec = causal_conv1d_fn(
            mixed_qkv_non_spec_T,
            conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
    elif attn_metadata.num_decodes > 0:
        mixed_qkv_non_spec = causal_conv1d_update(
            mixed_qkv_non_spec,
            conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=non_spec_state_indices_tensor[
                : attn_metadata.num_actual_tokens
            ],
            validate_data=True,
        )
    else:
        mixed_qkv_non_spec = None
    # Reshape each partition to [1, L, heads, dim] (None passes through).
    query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
    query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
        mixed_qkv_non_spec
    )
    # Explicit g/beta are only needed by the chunked prefill kernel; the
    # decode/spec kernels below take A_log, a, b and dt_bias directly.
    if attn_metadata.num_prefills > 0:
        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
        if spec_sequence_masks is not None:
            # Token axis of g/beta is dim 1 here.
            g_non_spec = g.index_select(1, non_spec_token_indx)
            beta_non_spec = beta.index_select(1, non_spec_token_indx)
        else:
            g_non_spec = g
            beta_non_spec = beta
    else:
        g_non_spec = None
        beta_non_spec = None
    # 2. Recurrent attention
    # 2.1: Process the multi-query part
    # NOTE(review): unlike mixed_qkv, `a` and `b` are only truncated, not
    # partitioned by spec_token_indx / non_spec_token_indx, before being
    # passed to fused_sigmoid_gating_delta_rule_update here and in the
    # decode branch below. Confirm the kernel maps their rows correctly when
    # spec and non-spec tokens are mixed in one batch.
    if spec_sequence_masks is not None:
        core_attn_out_spec, last_recurrent_state = (
            fused_sigmoid_gating_delta_rule_update(
                A_log=self.A_log,
                a=a,
                b=b,
                dt_bias=self.dt_bias,
                q=query_spec,
                k=key_spec,
                v=value_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=spec_query_start_loc[
                    : attn_metadata.num_spec_decodes + 1
                ],
                ssm_state_indices=spec_state_indices_tensor,
                num_accepted_tokens=num_accepted_tokens,
                use_qk_l2norm_in_kernel=True,
            )
        )
    else:
        core_attn_out_spec, last_recurrent_state = None, None
    # 2.2: Process the remaining part
    if attn_metadata.num_prefills > 0:
        # Gather the starting states; sequences without a prior state start
        # from zero.
        initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
        initial_state[~has_initial_state, ...] = 0
        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = self.chunk_gated_delta_rule(
            q=query_non_spec,
            k=key_non_spec,
            v=value_non_spec,
            g=g_non_spec,
            beta=beta_non_spec,
            initial_state=initial_state,
            output_final_state=True,
            cu_seqlens=non_spec_query_start_loc,
            use_qk_l2norm_in_kernel=True,
        )
        # Init cache
        # Scatter the final states back into the SSM cache in its dtype.
        ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
            ssm_state.dtype
        )
    elif attn_metadata.num_decodes > 0:
        # Decode updates ssm_state in place (inplace_final_state=True).
        core_attn_out_non_spec, last_recurrent_state = (
            fused_sigmoid_gating_delta_rule_update(
                A_log=self.A_log,
                a=a,
                b=b,
                dt_bias=self.dt_bias,
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=non_spec_query_start_loc[
                    : attn_metadata.num_decodes + 1
                ],
                ssm_state_indices=non_spec_state_indices_tensor,
                use_qk_l2norm_in_kernel=True,
            )
        )
    else:
        core_attn_out_non_spec, last_recurrent_state = None, None
    # 3. Merge core attention output
    # Both partitions present: scatter each back to its original token
    # positions along dim 1, then drop the leading singleton dim.
    if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
        merged_out = torch.empty(
            (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
            dtype=core_attn_out_non_spec.dtype,
            device=core_attn_out_non_spec.device,
        )
        merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
        merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
        core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
    elif spec_sequence_masks is not None:
        core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
    else:
        core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
def _forward_core_decode_non_spec(
    self,
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    attn_metadata: GDNAttentionMetadata,
):
    """Packed fast path for batches consisting only of non-spec decodes.

    Runs the causal conv1d state update on the first ``num_actual_tokens``
    rows. The result then goes to the packed recurrent gated-delta-rule
    decode kernel, which writes into ``core_attn_out[:num_actual_tokens]``
    (through an unsqueezed view) and uses the SSM cache as its initial
    state. Returns ``None``.
    """
    n_tok = attn_metadata.num_actual_tokens
    state_idx = attn_metadata.non_spec_state_indices_tensor[:n_tok]

    # kv_cache[0] holds (conv state, ssm state); conv state is transposed.
    layer_cache = self.kv_cache[0]
    conv_state = layer_cache[0].transpose(-1, -2)
    ssm_state = layer_cache[1]

    weight = self.conv1d.weight
    conv_weights = weight.view(weight.size(0), weight.size(2))

    conv_out = causal_conv1d_update(
        mixed_qkv[:n_tok],
        conv_state,
        conv_weights,
        self.conv1d.bias,
        self.activation,
        conv_state_indices=state_idx,
        validate_data=False,
    )
    # The kernel writes through this view directly into core_attn_out.
    out_view = core_attn_out[:n_tok].unsqueeze(1)
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=conv_out,
        a=a[:n_tok],
        b=b[:n_tok],
        A_log=self.A_log,
        dt_bias=self.dt_bias,
        scale=self.head_k_dim**-0.5,
        initial_state=ssm_state,
        out=out_view,
        ssm_state_indices=state_idx,
        use_qk_l2norm_in_kernel=True,
    )