vllm.v1.sample.ops.topk_topp_triton
Combined Top-K and Top-P Triton kernels.
Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs using Pivot-based Truncation and Selection" by Park et al. (https://arxiv.org/abs/2602.01518).
_update_min_larger_stats
Update running (min, count) of values above a pivot across tiles.
Tracks the smallest value strictly above a pivot and how many times it occurs. Called once per tile per pivot; the running state is carried across tiles via min_larger / num_min_larger.
Merge rule:
- tile min < running min → replace both
- tile min == running min → accumulate count
- tile min > running min → keep running values
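The merge rule can be checked with a small pure-Python sketch. It is not the Triton helper itself: here the hypothetical `update_min_larger_stats` takes a plain list as the tile and returns the new state rather than operating on Triton blocks, and the `min_larger_over_tiles` driver is likewise illustrative. Starting the running state at `(math.inf, 0)` is an assumption.

```python
import math


def update_min_larger_stats(tile, pivot, min_larger, num_min_larger):
    """Fold one tile into the running (min, count) of values strictly above pivot.

    Hypothetical pure-Python analogue of the Triton helper; signature is illustrative.
    """
    above = [v for v in tile if v > pivot]
    if not above:
        # Nothing above the pivot in this tile: running state is unchanged.
        return min_larger, num_min_larger
    tile_min = min(above)
    tile_count = above.count(tile_min)
    if tile_min < min_larger:
        # Smaller candidate found: replace both value and count.
        return tile_min, tile_count
    if tile_min == min_larger:
        # Same value seen again: accumulate its count.
        return min_larger, num_min_larger + tile_count
    # Tile minimum is larger: keep the running values.
    return min_larger, num_min_larger


def min_larger_over_tiles(row, pivot, tile_size):
    """Scan a row tile by tile, carrying (min_larger, num_min_larger) across tiles."""
    min_larger, num_min_larger = math.inf, 0  # assumed initial state
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        min_larger, num_min_larger = update_min_larger_stats(
            tile, pivot, min_larger, num_min_larger
        )
    return min_larger, num_min_larger
```

For example, with pivot 0.5 the row [0.1, 0.7, 0.3, 0.7, 0.9, 0.3, 0.7] split into tiles of three yields (0.7, 3): every tile's smallest value above the pivot is 0.7, so the counts accumulate across tiles.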
Source code in vllm/v1/sample/ops/topk_topp_triton.py
apply_top_k_top_p_triton
apply_top_k_top_p_triton(
logits: Tensor,
k: Tensor | None,
p: Tensor | None,
mask_value: float = float("-inf"),
) -> Tensor
Apply combined top-k and top-p masking using Triton.
Top-k is applied first (by logit value), then top-p is applied to the remaining k values (by probability).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logits | Tensor | [batch_size, vocab_size] float32 tensor, modified in-place | required |
| k | Tensor \| None | [batch_size] int32 tensor of top-k values per row, or None to disable top-k | required |
| p | Tensor \| None | [batch_size] float32 tensor of top-p values per row (0 to 1), or None to disable top-p | required |
| mask_value | float | Value for masked positions (default: -inf) | float('-inf') |
Returns:
| Type | Description |
|---|---|
| Tensor | The logits tensor (modified in-place) |
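The semantics above (top-k by logit first, then top-p by probability, per-row k and p, None disabling a stage, in-place masking) can be pinned down with a slow pure-Python reference, which may help when checking kernel output on small inputs. This is a sketch, not the Triton implementation: `reference_top_k_top_p` is a hypothetical name, nested lists stand in for tensors, and three conventions are assumptions rather than documented behavior: ties with the k-th largest logit are kept, top-p probabilities are a softmax renormalized over the top-k survivors, and top-p keeps the smallest descending prefix whose cumulative probability reaches p.

```python
import math


def reference_top_k_top_p(logits, k, p, mask_value=float("-inf")):
    """Hypothetical reference: top-k then top-p masking, mutating `logits` in place."""
    for i, row in enumerate(logits):
        kept = list(range(len(row)))
        # Stage 1: top-k by logit value (assumed: ties with the k-th value survive).
        if k is not None and k[i] < len(row):
            threshold = sorted(row, reverse=True)[k[i] - 1]
            kept = [j for j in kept if row[j] >= threshold]
        # Stage 2: top-p over a softmax renormalized across the survivors (assumed).
        if p is not None:
            m = max(row[j] for j in kept)
            if m != -math.inf:
                exps = {j: math.exp(row[j] - m) for j in kept}
                total = sum(exps.values())
                order = sorted(kept, key=lambda j: row[j], reverse=True)
                cumulative, cutoff = 0.0, len(order)
                for rank, j in enumerate(order):
                    cumulative += exps[j] / total
                    if cumulative >= p[i]:
                        cutoff = rank + 1  # keep the token that crosses p
                        break
                kept = order[:cutoff]
        # Write mask_value into every position that did not survive both stages.
        kept_set = set(kept)
        for j in range(len(row)):
            if j not in kept_set:
                row[j] = mask_value
    return logits
```

With logits [[2.0, 1.0, 0.5, -1.0], [0.0, 0.0, 3.0, 1.0]], k = [3, 4] and p = [0.8, 0.5], the first row keeps [2.0, 1.0] (top-3 drops -1.0; the two largest renormalized probabilities sum to about 0.86) and the second row keeps only 3.0 (its probability is about 0.81). Tie-breaking here may differ from the Triton kernel.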