Skip to content

vllm.v1.attention.backends.flashinfer

Attention layer with FlashInfer.

FLASHINFER_WORKSPACE_BUFFER_SIZE module-attribute

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

logger module-attribute

logger = init_logger(__name__)

FlashInferBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferBackend(AttentionBackend):
    """Attention backend descriptor that wires FlashInfer into vLLM V1.

    Exposes the implementation / metadata / builder classes, the supported
    dtypes and head sizes, the KV cache shape and memory layout, and the
    heuristic that decides whether decodes use the TRTLLM decode kernel.
    """

    accept_output_buffer: bool = True
    # Lazily filled in by `use_trtllm_decode_attention` with the result of
    # `current_platform.has_device_capability(100)`; None until first use.
    cached_sm100a_supported: Optional[bool] = None

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend accepts."""
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        sizes = [64, 128, 256]
        return sizes

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ValueError when `head_size` is not supported.

        The error message names the backend (class name without the
        "Backend" suffix) and suggests FlexAttention as a fallback.
        """
        allowed = cls.get_supported_head_sizes()
        if head_size in allowed:
            return
        backend_label = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {backend_label}. "
            f"Supported head sizes are: {allowed}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        backend_name = "FLASHINFER_VLLM_V1"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type[FlashInferImpl]:
        """Return the attention implementation class for this backend."""
        impl_cls = FlashInferImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> type[FlashInferMetadata]:
        """Return the per-batch attention metadata class."""
        metadata_cls = FlashInferMetadata
        return metadata_cls

    @staticmethod
    def get_builder_cls() -> type[FlashInferMetadataBuilder]:
        """Return the metadata builder class used with this backend."""
        builder_cls = FlashInferMetadataBuilder
        return builder_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV cache shape.

        Axis 1 has size 2: index 0 is used as the key cache and index 1 as
        the value cache by `FlashInferImpl.forward`. The physical memory
        layout is selected separately by `get_kv_cache_stride_order`.
        """
        kv_planes = 2
        return (num_blocks, kv_planes, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical KV cache shape to memory.

        `stride_order` maps `get_kv_cache_shape` to the desired physical
        layout: identity for "NHD", and heads/block-size swapped for "HND".

        Raises:
            ValueError: If the configured cache layout is unknown.
        """
        layout = get_kv_cache_layout()
        if layout == "HND":
            return (0, 1, 3, 2, 4)
        if layout == "NHD":
            return (0, 1, 2, 3, 4)
        raise ValueError(f"Unknown cache layout format {layout}.")

    @staticmethod
    def use_trtllm_decode_attention(
        batch_size: int,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_qo_heads: int,
        num_kv_heads: int,
        attn_head_size: int,
    ) -> bool:
        """Decide whether decode tokens should use the TRTLLM decode kernel.

        Returns False unless the device reports compute capability 100, the
        query/KV head ratio is an integer of at most 8, and the head size is
        128. Beyond that, an explicit VLLM_USE_TRTLLM_DECODE_ATTENTION value
        wins (any value other than "0" enables the kernel); otherwise the
        kernel is auto-selected for batch_size <= 256, max_seq_len < 131072
        and kv_cache_dtype == "auto".
        """
        backend = FlashInferBackend
        if backend.cached_sm100a_supported is None:
            backend.cached_sm100a_supported = (
                current_platform.has_device_capability(100))
        if not backend.cached_sm100a_supported:
            return False

        shape_supported = (num_qo_heads // num_kv_heads <= 8
                           and num_qo_heads % num_kv_heads == 0
                           and attn_head_size == 128)
        if not shape_supported:
            return False

        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
        if env_value is None:
            # No explicit setting: fall back to the auto-detection heuristic.
            use_trtllm = (backend.cached_sm100a_supported
                          and batch_size <= 256 and max_seq_len < 131072
                          and kv_cache_dtype == "auto")
            if use_trtllm:
                logger.warning_once(
                    "Using TRTLLM decode attention (auto-detected).")
            return use_trtllm

        # Explicit setting is respected; only "0" disables the kernel.
        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                         env_value)
        enabled = env_value != "0"
        if enabled:
            logger.info_once(
                "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
                "using TRTLLM decode attention.")
        return enabled

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map a vLLM fp8 KV cache dtype string to the torch fp8 dtype.

        Raises:
            ValueError: If `kv_cache_dtype` is not a recognised fp8 name.
        """
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

cached_sm100a_supported class-attribute instance-attribute

cached_sm100a_supported: Optional[bool] = None

get_builder_cls staticmethod

get_builder_cls() -> type[FlashInferMetadataBuilder]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_builder_cls() -> type[FlashInferMetadataBuilder]:
    """Return the metadata builder class used with this backend."""
    builder_cls = FlashInferMetadataBuilder
    return builder_cls

get_fp8_dtype_for_flashinfer staticmethod

get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> dtype
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

get_impl_cls staticmethod

get_impl_cls() -> type[FlashInferImpl]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_impl_cls() -> type[FlashInferImpl]:
    """Return the attention implementation class for this backend."""
    impl_cls = FlashInferImpl
    return impl_cls

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    """Return the logical KV cache shape.

    Axis 1 has size 2: index 0 is used as the key cache and index 1 as the
    value cache by `FlashInferImpl.forward`. The physical memory layout is
    selected separately by `get_kv_cache_stride_order`.
    """
    kv_planes = 2
    return (num_blocks, kv_planes, block_size, num_kv_heads, head_size)

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order() -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    """Return the permutation from the logical KV cache shape to memory.

    `stride_order` maps `get_kv_cache_shape` to the desired physical layout:
    identity for "NHD", and the heads/block-size axes swapped for "HND".

    Raises:
        ValueError: If the configured cache layout is unknown.
    """
    layout = get_kv_cache_layout()
    if layout == "HND":
        return (0, 1, 3, 2, 4)
    if layout == "NHD":
        return (0, 1, 2, 3, 4)
    raise ValueError(f"Unknown cache layout format {layout}.")

get_metadata_cls staticmethod

get_metadata_cls() -> type[FlashInferMetadata]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_metadata_cls() -> type[FlashInferMetadata]:
    """Return the per-batch attention metadata class."""
    metadata_cls = FlashInferMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_name() -> str:
    """Return the identifier string of this backend."""
    backend_name = "FLASHINFER_VLLM_V1"
    return backend_name

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    """Return the activation dtypes this backend accepts."""
    dtypes = [torch.float16, torch.bfloat16]
    return dtypes

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head sizes this backend accepts."""
    # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
    sizes = [64, 128, 256]
    return sizes

use_trtllm_decode_attention staticmethod

use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    attn_head_size: int,
) -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    attn_head_size: int,
) -> bool:
    """Decide whether decode tokens should use the TRTLLM decode kernel.

    Returns False unless the device reports compute capability 100, the
    query/KV head ratio is an integer of at most 8, and the head size is
    128. Beyond that, an explicit VLLM_USE_TRTLLM_DECODE_ATTENTION value
    wins (any value other than "0" enables the kernel); otherwise the kernel
    is auto-selected for batch_size <= 256, max_seq_len < 131072 and
    kv_cache_dtype == "auto".
    """
    backend = FlashInferBackend
    if backend.cached_sm100a_supported is None:
        backend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not backend.cached_sm100a_supported:
        return False

    shape_supported = (num_qo_heads // num_kv_heads <= 8
                       and num_qo_heads % num_kv_heads == 0
                       and attn_head_size == 128)
    if not shape_supported:
        return False

    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is None:
        # No explicit setting: fall back to the auto-detection heuristic.
        use_trtllm = (backend.cached_sm100a_supported
                      and batch_size <= 256 and max_seq_len < 131072
                      and kv_cache_dtype == "auto")
        if use_trtllm:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Explicit setting is respected; only "0" disables the kernel.
    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     env_value)
    enabled = env_value != "0"
    if enabled:
        logger.info_once(
            "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
            "using TRTLLM decode attention.")
    return enabled

validate_head_size classmethod

validate_head_size(head_size: int) -> None
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def validate_head_size(cls, head_size: int) -> None:
    """Raise ValueError when `head_size` is not supported.

    The error message names the backend (class name without the "Backend"
    suffix) and suggests FlexAttention as a fallback.
    """
    allowed = cls.get_supported_head_sizes()
    if head_size in allowed:
        return
    backend_label = cls.__name__.removesuffix("Backend")
    raise ValueError(
        f"Head size {head_size} is not supported by {backend_label}. "
        f"Supported head sizes are: {allowed}. "
        "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
        "FlexAttention backend which supports all head sizes.")

FlashInferImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferImpl(AttentionImpl):
    """Attention implementation backed by FlashInfer.

    ``forward`` writes new keys/values into the paged KV cache (unless the
    cache is shared with an earlier layer) and then dispatches to the
    cascade, prefill and decode wrappers carried in ``FlashInferMetadata``,
    or to ``trtllm_batch_decode_with_kv_cache`` for decodes when
    ``FlashInferBackend.use_trtllm_decode_attention`` selects it.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Store the per-layer attention configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Dimension of each attention head.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; converted to a float32
                tensor when given.
            sliding_window: Optional sliding window size; stored as the
                tuple ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when no
                window is configured.
            kv_cache_dtype: KV cache dtype string (e.g. "auto" or an fp8
                name accepted by ``get_fp8_dtype_for_flashinfer``).
            blocksparse_params: Accepted for interface compatibility; unused.
            logits_soft_cap: Optional logits soft cap.
            attn_type: Only ``AttentionType.DECODER`` is supported.
            kv_sharing_target_layer_name: Name of an earlier layer whose KV
                cache this layer reuses. When set, ``forward`` does not
                write new keys/values into the cache.
            use_irope: Stored as-is.

        Raises:
            NotImplementedError: If ``attn_type`` is not DECODER.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        # NOTE(review): alibi_slopes is stored but never passed to a kernel
        # in forward(); confirm ALiBi models are rejected elsewhere.
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.use_irope = use_irope

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            layer: The owning attention layer; its ``_k_scale``/``_v_scale``
                and ``_k_scale_float``/``_v_scale_float`` attributes are
                passed to the cache-write op and the attention kernels.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape depends on the KV cache layout:
                NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention; ``None`` for a profiling
                run, in which case ``output`` is returned untouched.
            output: Preallocated (possibly padded) output buffer. Required.
            output_scale: Must be ``None``; fused output quantization is not
                supported.
        Returns:
            The full (possibly padded) ``output`` buffer, filled in place,
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: If ``output_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8. This must
        # also happen when the cache is shared with an earlier layer (and
        # hence not written above): the cache is still read below, so it
        # needs the same fp8 reinterpretation.
        if self.kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype)
            kv_cache = kv_cache.view(torch_dtype)

        window_left = (self.sliding_window[0]
                       if self.sliding_window is not None else -1)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            # NOTE(review): unlike the paths below, this passes the
            # un-permuted kv_cache; confirm the cascade wrapper's layout
            # matches get_kv_cache_layout().
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            # Return the full buffer, consistent with every other path.
            return output_padded

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        if prefill_wrapper := attn_metadata.prefill_wrapper:
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None
            assert prefill_wrapper._causal
            assert prefill_wrapper._window_left == window_left
            assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                        or 0.0)
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
                kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
            )
        if decode_wrapper := attn_metadata.decode_wrapper:
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            if not FlashInferBackend.use_trtllm_decode_attention(
                    attn_metadata.num_decodes, attn_metadata.max_seq_len,
                    self.kv_cache_dtype, attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
                assert decode_wrapper is not None
                assert decode_wrapper._window_left == window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                           or 0.0)
                assert decode_wrapper._sm_scale == self.scale
                decode_wrapper.run(
                    decode_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
            else:
                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                if num_decode_tokens > 0:
                    # decode_query may be non-contiguous
                    decode_query = decode_query.contiguous()
                    block_tables_decode = attn_metadata.block_table_tensor[
                        :num_decode_tokens]
                    seq_lens_decode = attn_metadata.seq_lens[
                        :num_decode_tokens]

                    assert get_kv_cache_layout() == "HND"
                    assert decode_query.is_contiguous()
                    assert kv_cache_permute.is_contiguous()
                    assert block_tables_decode.is_contiguous()
                    assert seq_lens_decode.is_contiguous()

                    output[:num_decode_tokens] = (
                        trtllm_batch_decode_with_kv_cache(
                            query=decode_query,
                            kv_cache=kv_cache_permute,
                            workspace_buffer=attn_metadata.workspace_buffer,
                            num_heads=self.num_heads,
                            num_kv_heads=self.num_kv_heads,
                            scale=self.scale,
                            block_tables=block_tables_decode,
                            seq_lens=seq_lens_decode,
                            block_size=attn_metadata.page_size,
                            max_seq_len=attn_metadata.max_seq_len,
                            kv_cache_dtype=self.kv_cache_dtype,
                            k_scale=layer._k_scale_float,
                            v_scale=layer._v_scale_float,
                        ))
        return output_padded

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = (-1, -1)

use_irope instance-attribute

use_irope = use_irope

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = DECODER,
    kv_sharing_target_layer_name: Optional[int] = None,
    use_irope: bool = False,
) -> None
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None:
    """Store the per-layer attention configuration.

    Args:
        num_heads: Number of query heads.
        head_size: Dimension of each attention head.
        scale: Softmax scale; stored as a Python float.
        num_kv_heads: Number of key/value heads.
        alibi_slopes: Optional ALiBi slopes; converted to a float32 tensor
            when given.
        sliding_window: Optional sliding window size; stored as the tuple
            ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when no window is
            configured.
        kv_cache_dtype: KV cache dtype string (e.g. "auto" or an fp8 name).
        blocksparse_params: Accepted for interface compatibility; unused.
        logits_soft_cap: Optional logits soft cap.
        attn_type: Only ``AttentionType.DECODER`` is supported.
        kv_sharing_target_layer_name: Name of an earlier layer whose KV cache
            this layer reuses. When set, ``forward`` does not write new
            keys/values into the cache.
        use_irope: Stored as-is.

    Raises:
        NotImplementedError: If ``attn_type`` is not DECODER.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    # NOTE(review): alibi_slopes is stored but never passed to a kernel in
    # forward(); confirm ALiBi models are rejected elsewhere.
    self.alibi_slopes = alibi_slopes
    if sliding_window is None:
        self.sliding_window = (-1, -1)
    else:
        self.sliding_window = (sliding_window - 1, 0)
    self.kv_cache_dtype = kv_cache_dtype
    self.logits_soft_cap = logits_soft_cap
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
    self.use_irope = use_irope

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashInferImpl")

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashInferMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
) -> Tensor

Forward pass with FlashInfer.

Parameters:

Name Type Description Default
query Tensor

shape = [num_tokens, num_heads, head_size]

required
key Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
value Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
kv_cache Tensor

shape - depends on the KV cache layout: NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]; HND: [num_blocks, 2, num_kv_heads, block_size, head_size]

required
attn_metadata FlashInferMetadata

Metadata for attention.

required

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm/v1/attention/backends/flashinfer.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashInferMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with FlashInfer.

    Args:
        layer: The owning attention layer; its ``_k_scale``/``_v_scale``
            and ``_k_scale_float``/``_v_scale_float`` attributes are passed
            to the cache-write op and the attention kernels.
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape depends on the KV cache layout:
            NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
            HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
        attn_metadata: Metadata for attention; ``None`` for a profiling run,
            in which case ``output`` is returned untouched.
        output: Preallocated (possibly padded) output buffer. Required.
        output_scale: Must be ``None``; fused output quantization is not
            supported.
    Returns:
        The full (possibly padded) ``output`` buffer, filled in place,
        shape = [num_tokens, num_heads * head_size]

    Raises:
        NotImplementedError: If ``output_scale`` is given.
    """
    assert output is not None, "Output tensor must be provided."

    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for FlashInferImpl")

    if attn_metadata is None:
        # Profiling run.
        return output

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens

    if self.kv_sharing_target_layer_name is None:
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            kv_cache[:, 0],
            kv_cache[:, 1],
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
    # to process the cache when the kv_cache_dtype is fp8. This must also
    # happen when the cache is shared with an earlier layer (and hence not
    # written above): the cache is still read below, so it needs the same
    # fp8 reinterpretation.
    if self.kv_cache_dtype.startswith("fp8"):
        torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
            self.kv_cache_dtype)
        kv_cache = kv_cache.view(torch_dtype)

    window_left = (self.sliding_window[0]
                   if self.sliding_window is not None else -1)

    # Inputs and outputs may be padded for CUDA graphs
    query = query[:num_actual_tokens]
    output_padded = output
    output = output[:num_actual_tokens]

    if attn_metadata.use_cascade:
        # Cascade attention (rare case).
        # NOTE(review): unlike the paths below, this passes the un-permuted
        # kv_cache; confirm the cascade wrapper's layout matches
        # get_kv_cache_layout().
        assert attn_metadata.cascade_wrapper is not None
        output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
        # Return the full buffer, consistent with every other path.
        return output_padded

    num_decode_tokens = attn_metadata.num_decode_tokens
    num_prefill_tokens = attn_metadata.num_prefill_tokens

    stride_order = FlashInferBackend.get_kv_cache_stride_order()
    kv_cache_permute = kv_cache.permute(*stride_order)
    # Regular attention (common case).
    # Decodes are at the front and prefills are at the back,
    # according to reorder_batch()
    if prefill_wrapper := attn_metadata.prefill_wrapper:
        prefill_query = query[num_decode_tokens:]
        assert prefill_query.shape[0] == num_prefill_tokens
        assert prefill_wrapper is not None
        assert prefill_wrapper._causal
        assert prefill_wrapper._window_left == window_left
        assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                    or 0.0)
        assert prefill_wrapper._sm_scale == self.scale
        prefill_wrapper.run(
            prefill_query,
            kv_cache_permute,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
            out=output[num_decode_tokens:],
        )
    if decode_wrapper := attn_metadata.decode_wrapper:
        decode_query = query[:num_decode_tokens]
        assert decode_query.shape[0] == num_decode_tokens
        if not FlashInferBackend.use_trtllm_decode_attention(
                attn_metadata.num_decodes, attn_metadata.max_seq_len,
                self.kv_cache_dtype, attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads, attn_metadata.head_dim):
            assert decode_wrapper is not None
            assert decode_wrapper._window_left == window_left
            assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                       or 0.0)
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
                kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
            )
        else:
            # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
            if num_decode_tokens > 0:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                block_tables_decode = attn_metadata.block_table_tensor[
                    :num_decode_tokens]
                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                assert get_kv_cache_layout() == "HND"
                assert decode_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert block_tables_decode.is_contiguous()
                assert seq_lens_decode.is_contiguous()

                output[:num_decode_tokens] = (
                    trtllm_batch_decode_with_kv_cache(
                        query=decode_query,
                        kv_cache=kv_cache_permute,
                        workspace_buffer=attn_metadata.workspace_buffer,
                        num_heads=self.num_heads,
                        num_kv_heads=self.num_kv_heads,
                        scale=self.scale,
                        block_tables=block_tables_decode,
                        seq_lens=seq_lens_decode,
                        block_size=attn_metadata.page_size,
                        max_seq_len=attn_metadata.max_seq_len,
                        kv_cache_dtype=self.kv_cache_dtype,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                    ))
    return output_padded

FlashInferMetadata dataclass

Source code in vllm/v1/attention/backends/flashinfer.py
@dataclass
class FlashInferMetadata:
    """Per-batch attention metadata for the FlashInfer backend.

    Produced by ``FlashInferMetadataBuilder.build()``. It carries sizing
    information, the paged KV cache in FlashInfer's CSR-style layout, the
    decode/prefill split, optional cascade (shared-prefix) tensors, inputs
    for the TRT-LLM batch decode kernel, and the FlashInfer wrappers that
    were planned for this batch.
    """

    # Token count of the batch, padding excluded.
    num_actual_tokens: int

    # Cumulative query lengths with batch_size + 1 entries; per-request
    # query lengths [4, 6] give [0, 4, 10].
    qo_indptr: torch.Tensor
    # Paged KV layout (CSR style). With three requests owning pages
    # [0, 5, 8], [1, 6, 7] and [3, 4]:
    #   paged_kv_indices = [0, 5, 8, 1, 6, 7, 3, 4]   (all pages, concatenated)
    #   paged_kv_indptr  = [0, 3, 6, 8]               (per-request offsets)
    # paged_kv_indptr therefore has batch_size + 1 entries.
    paged_kv_indptr: torch.Tensor
    # Concatenated page indices of every request.
    paged_kv_indices: torch.Tensor
    # Occupied slots in each request's final page (batch_size entries).
    paged_kv_last_page_len: torch.Tensor
    # Query/output head count.
    num_qo_heads: int
    # Key/value head count.
    num_kv_heads: int
    # Hidden size of a single attention head.
    head_dim: int
    # vLLM block size, i.e. tokens per KV page.
    page_size: int
    # Element dtype stored in the paged KV cache.
    kv_data_type: torch.dtype
    # Element dtype of the queries.
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # Inputs consumed by the FlashInfer TRT-LLM batch decode kernel.
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    workspace_buffer: torch.Tensor

    # Decode/prefill split of the batch (decodes are ordered first).
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Cascade attention: the shared_* tensors describe the shared-prefix
    # level and are only populated when use_cascade is True.
    use_cascade: bool
    shared_qo_indptr: Optional[torch.Tensor] = None
    shared_kv_page_indptr: Optional[torch.Tensor] = None
    shared_kv_page_indices: Optional[torch.Tensor] = None
    shared_kv_last_page_len: Optional[torch.Tensor] = None

    # FlashInfer wrappers planned for this batch; left as None when unused.
    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
    cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

    @property
    def query_start_loc(self):
        """Alias of ``qo_indptr`` under the name GPUModelRunner reads."""
        return self.qo_indptr

    def __post_init__(self):
        # Fail fast on head sizes the FlashInfer kernels cannot handle.
        if self.head_dim is None:
            return
        FlashInferBackend.validate_head_size(self.head_dim)

block_table_tensor instance-attribute

block_table_tensor: Tensor

cascade_wrapper class-attribute instance-attribute

cascade_wrapper: Optional[
    MultiLevelCascadeAttentionWrapper
] = None

decode_wrapper class-attribute instance-attribute

decode_wrapper: Optional[
    BatchDecodeWithPagedKVCacheWrapper
] = None

head_dim instance-attribute

head_dim: int

kv_data_type instance-attribute

kv_data_type: dtype

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_kv_heads instance-attribute

num_kv_heads: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

num_qo_heads instance-attribute

num_qo_heads: int

page_size instance-attribute

page_size: int

paged_kv_indices instance-attribute

paged_kv_indices: Tensor

paged_kv_indptr instance-attribute

paged_kv_indptr: Tensor

paged_kv_last_page_len instance-attribute

paged_kv_last_page_len: Tensor

prefill_wrapper class-attribute instance-attribute

prefill_wrapper: Optional[
    BatchPrefillWithPagedKVCacheWrapper
] = None

q_data_type instance-attribute

q_data_type: dtype

qo_indptr instance-attribute

qo_indptr: Tensor

query_start_loc property

query_start_loc

seq_lens instance-attribute

seq_lens: Tensor

shared_kv_last_page_len class-attribute instance-attribute

shared_kv_last_page_len: Optional[Tensor] = None

shared_kv_page_indices class-attribute instance-attribute

shared_kv_page_indices: Optional[Tensor] = None

shared_kv_page_indptr class-attribute instance-attribute

shared_kv_page_indptr: Optional[Tensor] = None

shared_qo_indptr class-attribute instance-attribute

shared_qo_indptr: Optional[Tensor] = None

slot_mapping instance-attribute

slot_mapping: Tensor

use_cascade instance-attribute

use_cascade: bool

workspace_buffer instance-attribute

workspace_buffer: Tensor

__init__

__init__(
    num_actual_tokens: int,
    qo_indptr: Tensor,
    paged_kv_indptr: Tensor,
    paged_kv_indices: Tensor,
    paged_kv_last_page_len: Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    kv_data_type: dtype,
    q_data_type: dtype,
    slot_mapping: Tensor,
    max_seq_len: int,
    seq_lens: Tensor,
    block_table_tensor: Tensor,
    workspace_buffer: Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    num_prefill_tokens: int,
    use_cascade: bool,
    shared_qo_indptr: Optional[Tensor] = None,
    shared_kv_page_indptr: Optional[Tensor] = None,
    shared_kv_page_indices: Optional[Tensor] = None,
    shared_kv_last_page_len: Optional[Tensor] = None,
    prefill_wrapper: Optional[
        BatchPrefillWithPagedKVCacheWrapper
    ] = None,
    decode_wrapper: Optional[
        BatchDecodeWithPagedKVCacheWrapper
    ] = None,
    cascade_wrapper: Optional[
        MultiLevelCascadeAttentionWrapper
    ] = None,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/attention/backends/flashinfer.py
def __post_init__(self):
    """Validate ``head_dim`` against the FlashInfer backend after init.

    A ``None`` head_dim skips validation entirely.
    """
    head_dim = self.head_dim
    if head_dim is None:
        return
    FlashInferBackend.validate_head_size(head_dim)

FlashInferMetadataBuilder

Bases: AttentionMetadataBuilder[FlashInferMetadata]

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Builds FlashInferMetadata per batch and plans the FlashInfer wrappers.

    The workspace buffer and the prefill/decode/cascade wrappers are created
    lazily on first use and then reused for every subsequent batch.
    """

    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                 device: torch.device):
        self.device = device
        # Created lazily by the _get_* helpers below.
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode
        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers; inferred on
        # the first _plan() call.
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = vllm_config
        self.cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        """Reorder the batch so decode requests come before prefill requests.

        _plan() relies on this ordering. The returned bool is passed through
        from reorder_batch_to_split_decodes_and_prefills (presumably whether
        the batch was modified -- confirm against that helper).
        """
        return reorder_batch_to_split_decodes_and_prefills(input_batch,
                                                           scheduler_output,
                                                           decode_threshold=1)

    def _get_workspace_buffer(self):
        """Return the shared uint8 workspace buffer, allocating it once."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the cached prefill/append wrapper, creating it once."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the cached decode wrapper, creating it once.

        Tensor-core decode kernels are requested when forced through
        VLLM_FLASHINFER_FORCE_TENSOR_CORES or when more than four query
        heads share each KV head.
        """
        if self._decode_wrapper is None:
            num_qo_heads = (
                self.vllm_config.model_config.get_num_attention_heads(
                    self.vllm_config.parallel_config))
            num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
                self.vllm_config.parallel_config)
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    def _get_cascade_wrapper(self):
        """Return the cached two-level cascade wrapper, creating it once."""
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout())
        return self._cascade_wrapper

    def _plan(self, num_prefills: int, num_decodes: int,
              attn_metadata: FlashInferMetadata):
        """Plan the FlashInfer wrapper(s) for this batch.

        The planned wrappers are attached to ``attn_metadata``, which is
        mutated in place. With cascade attention a two-level wrapper (shared
        prefix plus per-request suffix) is planned. Otherwise the prefill
        wrapper is planned for the prefill requests, and the decode wrapper
        for the decode requests unless the TRT-LLM decode kernel is used.
        """
        if self.global_hyperparameters is None:
            self.global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config, FlashInferImpl))
        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                [
                    attn_metadata.shared_kv_page_indptr,
                    attn_metadata.paged_kv_indptr
                ],
                [
                    attn_metadata.shared_kv_page_indices,
                    attn_metadata.paged_kv_indices
                ],
                [
                    attn_metadata.shared_kv_last_page_len,
                    attn_metadata.paged_kv_last_page_len
                ],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.kv_data_type,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back,
            # according to reorder_batch()
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert attn_metadata.qo_indptr[prefill_start:].shape[
                    0] == num_prefills + 1
                assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
                    0] == num_prefills + 1
                assert attn_metadata.paged_kv_last_page_len[
                    prefill_start:].shape[0] == num_prefills
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr = attn_metadata.qo_indptr[
                    prefill_start:] - attn_metadata.qo_indptr[prefill_start]
                # paged_kv_indptr keeps absolute offsets, so the full
                # paged_kv_indices list is passed unchanged.
                attn_metadata.prefill_wrapper.plan(
                    qo_indptr,
                    attn_metadata.paged_kv_indptr[prefill_start:],
                    attn_metadata.paged_kv_indices,
                    attn_metadata.paged_kv_last_page_len[prefill_start:],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    causal=True,
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.kv_data_type,
                )

            if num_decodes > 0:
                attn_metadata.decode_wrapper = self._get_decode_wrapper()
                # The TRT-LLM decode path calls its kernel directly at run
                # time, so the decode wrapper is only planned otherwise.
                if not FlashInferBackend.use_trtllm_decode_attention(
                        num_decodes, attn_metadata.max_seq_len,
                        self.cache_config.cache_dtype,
                        attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                        attn_metadata.head_dim):
                    attn_metadata.decode_wrapper.plan(
                        attn_metadata.paged_kv_indptr[:num_decodes + 1],
                        attn_metadata.paged_kv_indices,
                        attn_metadata.paged_kv_last_page_len[:num_decodes],
                        attn_metadata.num_qo_heads,
                        attn_metadata.num_kv_heads,
                        attn_metadata.head_dim,
                        attn_metadata.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=attn_metadata.q_data_type,
                        kv_data_type=attn_metadata.kv_data_type,
                    )

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashInferMetadata:
        """Build and plan FlashInferMetadata for one batch.

        Args:
            common_prefix_len: Tokens shared by all requests. When positive,
                cascade attention is used and the value must be a multiple
                of the page size.
            common_attn_metadata: Backend-agnostic batch metadata (query
                start offsets, sequence lengths, block table, slot mapping).
            fast_build: Accepted for interface compatibility; unused here.

        Returns:
            The populated FlashInferMetadata with its wrappers planned.
        """
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
            split_decodes_and_prefills(common_attn_metadata)

        page_size = self.kv_cache_spec.block_size
        device = self.device
        qo_indptr = common_attn_metadata.query_start_loc
        # FlashInferMetadata.max_seq_len is declared as an int and is handed
        # to kernels/heuristics as a scalar; convert the 0-d tensor here.
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor

        # Number of KV pages each request occupies (ceil division).
        block_table_bounds = (seq_lens + page_size - 1) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size
            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device=device)
            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                                 dtype=torch.int32,
                                                 device=device)
            shared_kv_page_indices = block_table_tensor[
                0, :num_common_kv_blocks]
            shared_kv_last_page_len = torch.tensor([page_size],
                                                   dtype=torch.int32,
                                                   device=device)
            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            block_table_bounds -= num_common_kv_blocks
        else:
            shared_qo_indptr = None
            shared_kv_page_indptr = None
            shared_kv_page_indices = None
            shared_kv_last_page_len = None

        # Keep only the valid (in-bounds) entries of each block-table row and
        # flatten them into one concatenated page-index list.
        mask = (torch.arange(block_table_tensor.size(1),
                             dtype=block_table_tensor.dtype,
                             device=block_table_tensor.device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table_tensor[mask]

        # Per-request offsets into paged_kv_indices. The leading zero is
        # int32 like the cumsum so the concatenation cannot promote to a
        # wider integer dtype.
        paged_kv_indptr = torch.cat([
            torch.zeros(1,
                        dtype=torch.int32,
                        device=block_table_bounds.device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        # Tokens in each request's last page; a full page gives 0 from the
        # modulo and must be reported as page_size.
        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size, paged_kv_last_page_len)
        cache_dtype = self.cache_config.cache_dtype
        if cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                cache_dtype)
        else:
            kv_cache_dtype = self.kv_cache_spec.dtype
        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            qo_indptr=qo_indptr,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                self.vllm_config.parallel_config),
            num_kv_heads=self.kv_cache_spec.num_kv_heads,
            head_dim=self.kv_cache_spec.head_size,
            page_size=page_size,
            kv_data_type=kv_cache_dtype,
            q_data_type=self.vllm_config.model_config.dtype,
            slot_mapping=common_attn_metadata.slot_mapping,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
            shared_qo_indptr=shared_qo_indptr,
            shared_kv_page_indptr=shared_kv_page_indptr,
            shared_kv_page_indices=shared_kv_page_indices,
            shared_kv_last_page_len=shared_kv_last_page_len,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table_tensor=block_table_tensor,
            # The buffer is otherwise only created when _plan() below first
            # fetches a wrapper, so reading self._workspace_buffer here would
            # store None on the first call; the TRT-LLM decode path reads
            # this field directly.
            workspace_buffer=self._get_workspace_buffer(),
        )

        self._plan(num_prefills, num_decodes, attn_metadata)

        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Decide whether cascade attention should be used for this batch.

        Returns False when the KV cache dtype differs from the model dtype;
        otherwise defers to the generic use_cascade_attention heuristic.
        """
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        return use_cascade_attention(*args, **kwargs)

_cascade_wrapper instance-attribute

_cascade_wrapper = None

_decode_wrapper instance-attribute

_decode_wrapper = None

_prefill_wrapper instance-attribute

_prefill_wrapper = None

_workspace_buffer instance-attribute

_workspace_buffer = None

cache_config instance-attribute

cache_config = cache_config

device instance-attribute

device = device

global_hyperparameters instance-attribute

global_hyperparameters: Optional[PerLayerParameters] = None

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
             device: torch.device):
    """Record configuration; FlashInfer resources are created lazily.

    Args:
        kv_cache_spec: Attention spec of the KV cache served by this builder.
        vllm_config: Global vLLM configuration (its cache_config is cached).
        device: Device on which metadata tensors and the workspace live.
    """
    self.kv_cache_spec = kv_cache_spec
    self.vllm_config = vllm_config
    self.cache_config = vllm_config.cache_config
    self.device = device

    # Expensive FlashInfer objects start unset; the _get_* helpers build
    # them on first use and cache them here.
    self._workspace_buffer = None
    self._prefill_wrapper = None   # prefill / append
    self._decode_wrapper = None    # decode
    self._cascade_wrapper = None   # cascade (shared prefix)

    # sm_scale / window_left / logits_soft_cap common to every attention
    # layer; filled in on the first planning call.
    self.global_hyperparameters: Optional[PerLayerParameters] = None

_get_cascade_wrapper

_get_cascade_wrapper()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_cascade_wrapper(self):
    """Return the two-level cascade wrapper, constructing it on first use."""
    if self._cascade_wrapper is not None:
        return self._cascade_wrapper
    # Level 0 is the shared prefix, level 1 the per-request remainder.
    num_levels = 2
    self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
        num_levels, self._get_workspace_buffer(), get_kv_cache_layout())
    return self._cascade_wrapper

_get_decode_wrapper

_get_decode_wrapper()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_decode_wrapper(self):
    """Return the decode wrapper, constructing it on first use.

    Tensor-core kernels are requested when forced through
    VLLM_FLASHINFER_FORCE_TENSOR_CORES or when the number of query heads
    per KV head exceeds four.
    """
    if self._decode_wrapper is not None:
        return self._decode_wrapper
    model_config = self.vllm_config.model_config
    parallel_config = self.vllm_config.parallel_config
    num_qo_heads = model_config.get_num_attention_heads(parallel_config)
    num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    group_size = num_qo_heads // num_kv_heads
    use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES
                        or group_size > 4)
    self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        self._get_workspace_buffer(),
        get_kv_cache_layout(),
        use_tensor_cores=use_tensor_cores,
    )
    return self._decode_wrapper

_get_prefill_wrapper

_get_prefill_wrapper()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_prefill_wrapper(self):
    """Return the prefill/append wrapper, constructing it on first use."""
    if self._prefill_wrapper is not None:
        return self._prefill_wrapper
    workspace = self._get_workspace_buffer()
    self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
        workspace, get_kv_cache_layout())
    return self._prefill_wrapper

_get_workspace_buffer

_get_workspace_buffer()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_workspace_buffer(self):
    """Return the uint8 FlashInfer workspace, allocating it on first use.

    The same buffer is shared by every wrapper this builder creates.
    """
    buffer = self._workspace_buffer
    if buffer is None:
        buffer = torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                             dtype=torch.uint8,
                             device=self.device)
        self._workspace_buffer = buffer
    return buffer

_plan

_plan(
    num_prefills: int,
    num_decodes: int,
    attn_metadata: FlashInferMetadata,
)
Source code in vllm/v1/attention/backends/flashinfer.py
def _plan(self, num_prefills: int, num_decodes: int,
          attn_metadata: FlashInferMetadata):
    """Plan the FlashInfer wrapper(s) for the batch in ``attn_metadata``.

    ``attn_metadata`` is mutated in place: the planned wrapper(s) are
    stored on it. With cascade attention a two-level cascade wrapper
    (shared prefix + per-request suffix) is planned. Otherwise the prefill
    wrapper is planned for the prefill requests, and the decode wrapper for
    the decode requests unless the TRT-LLM decode kernel is selected.

    Args:
        num_prefills: Number of prefill requests (ordered after decodes).
        num_decodes: Number of decode requests (ordered first).
        attn_metadata: Metadata produced by ``build()``.
    """
    # Inferred once from all attention layers, then reused for every batch.
    if self.global_hyperparameters is None:
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(self.vllm_config, FlashInferImpl))
    if attn_metadata.use_cascade:
        # Each list has one entry per cascade level: [shared prefix, rest].
        attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
        attn_metadata.cascade_wrapper.plan(
            [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
            [
                attn_metadata.shared_kv_page_indptr,
                attn_metadata.paged_kv_indptr
            ],
            [
                attn_metadata.shared_kv_page_indices,
                attn_metadata.paged_kv_indices
            ],
            [
                attn_metadata.shared_kv_last_page_len,
                attn_metadata.paged_kv_last_page_len
            ],
            attn_metadata.num_qo_heads,
            attn_metadata.num_kv_heads,
            attn_metadata.head_dim,
            attn_metadata.page_size,
            causal=True,
            sm_scale=self.global_hyperparameters.sm_scale,
            window_left=self.global_hyperparameters.window_left,
            logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
            q_data_type=attn_metadata.q_data_type,
            kv_data_type=attn_metadata.kv_data_type,
        )
    else:
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        if num_prefills > 0:
            # Decodes are first so prefills start after the last decode
            prefill_start = num_decodes
            attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
            # Sanity-check that the sliced CSR tensors cover exactly the
            # prefill requests.
            assert attn_metadata.qo_indptr[prefill_start:].shape[
                0] == num_prefills + 1
            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
                0] == num_prefills + 1
            assert attn_metadata.paged_kv_last_page_len[
                prefill_start:].shape[0] == num_prefills
            # Since prefill_wrapper.run() will be called with
            # query[num_decode_tokens:] we need to adjust the qo_indptr
            # to be relative to the start of the prefill queries.
            qo_indptr = attn_metadata.qo_indptr[
                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
            # paged_kv_indptr is not rebased: its values stay absolute
            # offsets into the full paged_kv_indices passed below.
            attn_metadata.prefill_wrapper.plan(
                qo_indptr,
                attn_metadata.paged_kv_indptr[prefill_start:],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[prefill_start:],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.
                logits_soft_cap,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.kv_data_type,
            )

        if num_decodes > 0:
            attn_metadata.decode_wrapper = self._get_decode_wrapper()
            # The decode wrapper is only planned when the TRT-LLM decode
            # kernel is not used; that path calls its kernel directly at
            # run time instead of going through the wrapper.
            if not FlashInferBackend.use_trtllm_decode_attention(
                    num_decodes, attn_metadata.max_seq_len,
                    self.cache_config.cache_dtype,
                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                    attn_metadata.head_dim):
                attn_metadata.decode_wrapper.plan(
                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
                    attn_metadata.paged_kv_indices,
                    attn_metadata.paged_kv_last_page_len[:num_decodes],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.kv_data_type,
                )

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashInferMetadata
Source code in vllm/v1/attention/backends/flashinfer.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> FlashInferMetadata:
    """Build and plan FlashInferMetadata for one batch.

    Args:
        common_prefix_len: Tokens shared by all requests. When positive,
            cascade attention is used and the value must be a multiple of
            the page size.
        common_attn_metadata: Backend-agnostic batch metadata (query start
            offsets, sequence lengths, block table, slot mapping).
        fast_build: Accepted for interface compatibility; unused here.

    Returns:
        The populated FlashInferMetadata with its wrappers planned.
    """
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
        split_decodes_and_prefills(common_attn_metadata)

    page_size = self.kv_cache_spec.block_size
    device = self.device
    qo_indptr = common_attn_metadata.query_start_loc
    # FlashInferMetadata.max_seq_len is declared as an int and is handed to
    # kernels/heuristics as a scalar; convert the 0-d tensor here.
    max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
    seq_lens = common_attn_metadata.seq_lens
    block_table_tensor = common_attn_metadata.block_table_tensor

    # Number of KV pages each request occupies (ceil division).
    block_table_bounds = (seq_lens + page_size - 1) // page_size

    use_cascade = common_prefix_len > 0
    if use_cascade:
        # Grab the blocks of the shared prefix from the first request.
        assert common_prefix_len % page_size == 0
        num_common_kv_blocks = common_prefix_len // page_size
        shared_qo_indptr = torch.tensor([0, num_actual_tokens],
                                        dtype=torch.int32,
                                        device=device)
        shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                             dtype=torch.int32,
                                             device=device)
        shared_kv_page_indices = block_table_tensor[
            0, :num_common_kv_blocks]
        shared_kv_last_page_len = torch.tensor([page_size],
                                               dtype=torch.int32,
                                               device=device)
        # Remove the blocks of the shared prefix from all requests.
        block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
        block_table_bounds -= num_common_kv_blocks
    else:
        shared_qo_indptr = None
        shared_kv_page_indptr = None
        shared_kv_page_indices = None
        shared_kv_last_page_len = None

    # Keep only the valid (in-bounds) entries of each block-table row and
    # flatten them into one concatenated page-index list.
    mask = (torch.arange(block_table_tensor.size(1),
                         dtype=block_table_tensor.dtype,
                         device=block_table_tensor.device).unsqueeze(0)
            < block_table_bounds.unsqueeze(1))
    paged_kv_indices = block_table_tensor[mask]

    # Per-request offsets into paged_kv_indices. The leading zero is int32
    # like the cumsum so the concatenation cannot promote to a wider dtype.
    paged_kv_indptr = torch.cat([
        torch.zeros(1,
                    dtype=torch.int32,
                    device=block_table_bounds.device),
        block_table_bounds.cumsum(dim=0, dtype=torch.int32)
    ])

    # Tokens in each request's last page; a full page gives 0 from the
    # modulo and must be reported as page_size.
    paged_kv_last_page_len = seq_lens % page_size
    paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                         page_size, paged_kv_last_page_len)
    cache_dtype = self.cache_config.cache_dtype
    if cache_dtype.startswith("fp8"):
        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
            cache_dtype)
    else:
        kv_cache_dtype = self.kv_cache_spec.dtype
    attn_metadata = FlashInferMetadata(
        num_actual_tokens=num_actual_tokens,
        qo_indptr=qo_indptr,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config),
        num_kv_heads=self.kv_cache_spec.num_kv_heads,
        head_dim=self.kv_cache_spec.head_size,
        page_size=page_size,
        kv_data_type=kv_cache_dtype,
        q_data_type=self.vllm_config.model_config.dtype,
        slot_mapping=common_attn_metadata.slot_mapping,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        use_cascade=use_cascade,
        shared_qo_indptr=shared_qo_indptr,
        shared_kv_page_indptr=shared_kv_page_indptr,
        shared_kv_page_indices=shared_kv_page_indices,
        shared_kv_last_page_len=shared_kv_last_page_len,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        block_table_tensor=block_table_tensor,
        # The buffer is otherwise only created when _plan() below first
        # fetches a wrapper, so reading self._workspace_buffer here would
        # store None on the first call; the TRT-LLM decode path reads this
        # field directly.
        workspace_buffer=self._get_workspace_buffer(),
    )

    self._plan(num_prefills, num_decodes, attn_metadata)

    return attn_metadata

reorder_batch

reorder_batch(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
) -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
def reorder_batch(self, input_batch: InputBatch,
                  scheduler_output: SchedulerOutput) -> bool:
    """Put decode requests ahead of prefill requests in ``input_batch``.

    Delegates to reorder_batch_to_split_decodes_and_prefills with a decode
    threshold of 1 (presumably: requests scheduling at most one token count
    as decodes -- confirm against that helper) and returns its result.
    """
    threshold = 1
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch, scheduler_output, decode_threshold=threshold)

use_cascade_attention

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
def use_cascade_attention(self, *args, **kwargs) -> bool:
    """Report whether cascade attention may be used for this batch.

    Cascade is ruled out whenever the KV cache dtype differs from the model
    dtype; otherwise the decision is forwarded, with all arguments, to the
    generic ``use_cascade_attention`` heuristic.
    """
    kv_dtype = self.kv_cache_spec.dtype
    query_dtype = self.vllm_config.model_config.dtype
    if kv_dtype == query_dtype:
        return use_cascade_attention(*args, **kwargs)
    # TODO: the cascade wrapper cannot yet run with a KV cache dtype that
    # differs from the query dtype.
    return False