vllm.attention.ops.common ¶
CPTritonContext ¶
The CPTritonContext is used to avoid recompilation of the Triton JIT.
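A minimal sketch of the caching idea (attribute and method names here are illustrative assumptions, not necessarily the exact class body):

```python
# Hypothetical sketch: cache the compiled Triton kernel after the first
# launch so later calls skip the JIT specialization path.
class CPTritonContext:
    def __init__(self):
        self.inner_kernel = None  # set on the first kernel launch

    def call_kernel(self, kernel, grid, *args, **const_args):
        if self.inner_kernel is None:
            # First launch goes through the JIT wrapper; keep the handle.
            self.inner_kernel = kernel[grid](*args, **const_args)
        else:
            # Reuse the already-compiled kernel on subsequent launches.
            self.inner_kernel[grid](*args)
```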
_correct_attn_cp_out_kernel ¶
```python
_correct_attn_cp_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: constexpr,
    N_ROUNDED: constexpr,
)
```
Apply the all-gathered lses to correct each local rank's attention output. A cross-rank reduction is still needed afterwards to obtain the final attention output.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| outputs_ptr | PointerType | Pointer to input tensor of shape [B, H, D] | required |
| lses_ptr | PointerType | Pointer to input tensor of shape [N, B, H] | required |
| new_output_ptr | PointerType | Pointer to output tensor of shape [B, H, D] | required |
| vlse_ptr | PointerType | Pointer to output tensor of shape [B, H] | required |
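For reference, a hedged pure-PyTorch restatement of what the kernel computes per rank (a simplified sketch that ignores numerical edge cases the kernel may handle; `correct_attn_out_reference` is a hypothetical name):

```python
import torch

def correct_attn_out_reference(out: torch.Tensor,   # [B, H, D] local partial attention output
                               lses: torch.Tensor,  # [N, B, H] all-gathered log-sum-exps
                               cp_rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Global log-sum-exp across the N context-parallel ranks.
    lse = torch.logsumexp(lses, dim=0)          # [B, H]
    # Rescale the local partial output by exp(lse_local - lse_global).
    # Summing the rescaled outputs over all ranks yields the exact result,
    # which is the cross-rank reduction mentioned above.
    factor = torch.exp(lses[cp_rank] - lse)     # [B, H]
    return out * factor.unsqueeze(-1), lse
```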
correct_attn_out ¶
```python
correct_attn_out(
    out: Tensor,
    lses: Tensor,
    cp_rank: int,
    ctx: CPTritonContext,
) -> tuple[Tensor, Tensor]
```
Correct the attention output using the all-gathered lses.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| out | Tensor | Tensor of shape [B, H, D] | required |
| lses | Tensor | Tensor of shape [N, B, H] | required |
| cp_rank | int | Current rank in the context-parallel group | required |
| ctx | CPTritonContext | Triton context to avoid recompilation | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Tensor, Tensor] | Tuple of (out, lse) with corrected attention and final log-sum-exp. |
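A hedged usage sketch; the shapes and device are illustrative, and in practice `lses` comes from an all-gather across the context-parallel group:

```python
import torch
from vllm.attention.ops.common import CPTritonContext, correct_attn_out

N, B, H, D = 4, 2, 8, 128                   # illustrative sizes
out = torch.randn(B, H, D, device="cuda")   # this rank's partial output
lses = torch.randn(N, B, H, device="cuda")  # all-gathered log-sum-exps

ctx = CPTritonContext()                     # reuse across calls to avoid re-JIT
out, lse = correct_attn_out(out, lses, cp_rank=0, ctx=ctx)
```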
cp_lse_ag_out_rs ¶
```python
cp_lse_ag_out_rs(
    cp_attn_out: Tensor,
    cp_attn_lse: Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext = None,
)
```
cp_attn_out: [B, H, D]
cp_attn_lse: [B, H]

All-gathers the per-rank LSEs across cp_group, corrects the local attention output with them, and reduce-scatters the corrected output (the "ag"/"rs" in the name).
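Going by the name and the helpers above, the flow can be restated with raw torch.distributed collectives. This is an illustrative sketch under those assumptions, not the library code, which drives the collectives through the GroupCoordinator; the head-dimension split in step 3 is also an assumption:

```python
import torch
import torch.distributed as dist
from vllm.attention.ops.common import CPTritonContext, correct_attn_out

def cp_lse_ag_out_rs_sketch(cp_attn_out: torch.Tensor,  # [B, H, D]
                            cp_attn_lse: torch.Tensor,  # [B, H]
                            group: dist.ProcessGroup,
                            ctx: CPTritonContext) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # 1. All-gather every rank's log-sum-exp: [B, H] -> [N, B, H].
    lses = torch.empty((world_size, *cp_attn_lse.shape),
                       dtype=cp_attn_lse.dtype, device=cp_attn_lse.device)
    dist.all_gather_into_tensor(lses, cp_attn_lse.contiguous(), group=group)

    # 2. Rescale the local partial output against the global LSE
    #    (the correct_attn_out kernel documented above).
    out, _ = correct_attn_out(cp_attn_out, lses, rank, ctx)

    # 3. Reduce-scatter the rescaled outputs so each rank ends up with the
    #    exact (summed) output for its assumed shard of heads.
    out = out.movedim(1, 0).contiguous()                  # [H, B, D]
    shard = torch.empty((out.shape[0] // world_size, *out.shape[1:]),
                        dtype=out.dtype, device=out.device)
    dist.reduce_scatter_tensor(shard, out, group=group)
    return shard.movedim(0, 1)                            # [B, H // N, D]
```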