vllm.model_executor.layers.quantization.modelopt

KV_CACHE_QUANT_ALGOS module-attribute

KV_CACHE_QUANT_ALGOS = ['FP8']

QUANT_ALGOS module-attribute

QUANT_ALGOS = ['FP8', 'NVFP4']

logger module-attribute

logger = init_logger(__name__)

ModelOptFp8Config

Bases: QuantizationConfig

Config class for ModelOpt FP8.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        kv_cache_quant_method = cls.get_from_keys(
            config, ["quantization"]).get("kv_cache_quant_algo")
        exclude_modules = cls.get_from_keys(
            config, ["quantization"]).get("exclude_modules")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                             " quantizations in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration.")
        is_checkpoint_fp8_serialized = ("FP8" in quant_method)

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
                   exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # Check if any excluded module matches the prefix
        for module in self.exclude_modules:
            if (module in prefix
                    or (prefix.startswith("language_model.")
                        and module in prefix.removeprefix("language_model."))):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

kv_cache_quant_method instance-attribute

kv_cache_quant_method = kv_cache_quant_method

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None:
    super().__init__()
    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
    self.kv_cache_quant_method = kv_cache_quant_method
    self.exclude_modules = exclude_modules
    if is_checkpoint_fp8_serialized:
        logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                       " the format is experimental and could change.")

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptFp8Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
    quant_config = cls.get_from_keys(config, ["quantization"])
    quant_method = quant_config["quant_algo"]
    kv_cache_quant_method = cls.get_from_keys(
        config, ["quantization"]).get("kv_cache_quant_algo")
    exclude_modules = cls.get_from_keys(
        config, ["quantization"]).get("exclude_modules")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                         " quantizations in vLLM. Please check the "
                         "`hf_quant_config.json` file for your model's "
                         "quant configuration.")
    is_checkpoint_fp8_serialized = ("FP8" in quant_method)

    return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
               exclude_modules)
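
For reference, a hedged sketch of the dict that from_config consumes (the parsed contents of hf_quant_config.json). The field values below are illustrative assumptions, not taken from any particular checkpoint.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

# Minimal parsed hf_quant_config.json with a "quantization" section.
example_hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",             # must be one of QUANT_ALGOS
        "kv_cache_quant_algo": "FP8",    # optional: FP8 KV-cache scales
        "exclude_modules": ["lm_head"],  # optional: layers left unquantized
    },
}

config = ModelOptFp8Config.from_config(example_hf_quant_config)
assert config.is_checkpoint_fp8_serialized
assert config.kv_cache_quant_method == "FP8"
assert config.exclude_modules == ["lm_head"]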

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    return 89

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "modelopt"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import
    if isinstance(layer, LinearBase):
        if self.is_layer_excluded(prefix):
            return UnquantizedLinearMethod()
        return ModelOptFp8LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return ModelOptFp8MoEMethod(self)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

is_layer_excluded

is_layer_excluded(prefix: str) -> bool

Check if a layer should be excluded from quantization.

This method handles both regular models and multimodal models that use the language_model prefix. For multimodal models, it checks if the module name (without the language_model prefix) is in the exclude list.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str) -> bool:
    """
    Check if a layer should be excluded from quantization.

    This method handles both regular models and multimodal models that use
    the language_model prefix. For multimodal models, it checks if the
    module name (without the language_model prefix) is in the exclude list.
    """
    if self.exclude_modules is None:
        return False

    # Check if any excluded module matches the prefix
    for module in self.exclude_modules:
        if (module in prefix
                or (prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model."))):
            return True
    return False
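
A hedged usage sketch of the matching behavior above; the module names are illustrative.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

config = ModelOptFp8Config(
    is_checkpoint_fp8_serialized=True,
    exclude_modules=["lm_head", "vision_tower"],
)

# Exclusion is a substring match against the module prefix, with the
# optional "language_model." prefix of multimodal models also handled.
assert config.is_layer_excluded("lm_head")
assert config.is_layer_excluded("language_model.lm_head")
assert not config.is_layer_excluded("model.layers.0.self_attn.qkv_proj")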

ModelOptFp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Union[ModelOptFp8Config,
                                           ModelOptNvFp4Config]):
        super().__init__(quant_config)

__init__

__init__(
    quant_config: Union[
        ModelOptFp8Config, ModelOptNvFp4Config
    ],
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: Union[ModelOptFp8Config,
                                       ModelOptNvFp4Config]):
    super().__init__(quant_config)

ModelOptFp8LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales.

Limitations:

1. Only supports per-tensor quantization due to torch._scaled_mm support.
2. Only supports the float8_e4m3fn datatype.

Args:
    quant_config: The ModelOpt quantization config.
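
As a hedged numeric sketch of the per-tensor scheme above (reference math only, not the Fp8LinearOp / torch._scaled_mm path), both the weight and the activation are quantized with a single static scale each:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_per_tensor_fp8(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor, clamped to the FP8-E4M3 range.
    return (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

x = torch.randn(4, 16)          # activation
w = torch.randn(8, 16)          # weight (out_features x in_features)
x_scale = x.abs().max() / FP8_MAX
w_scale = w.abs().max() / FP8_MAX

x_q = quantize_per_tensor_fp8(x, x_scale)
w_q = quantize_per_tensor_fp8(w, w_scale)

# Dequantized reference matmul; the real kernel fuses the scaling instead.
y_ref = (x_q.float() * x_scale) @ (w_q.float() * w_scale).t()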

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype
        Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths)
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(),
                                      requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     input_scale=layer.input_scale,
                                     bias=bias)

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=True, act_quant_group_shape=PER_TENSOR
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptFp8Config):
    self.quant_config = quant_config
    self.fp8_linear = Fp8LinearOp(
        act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return self.fp8_linear.apply(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
                                 input_scale=layer.input_scale,
                                 bias=bias)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_fp8_serialized else
                    params_dtype)
    weight = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=weight_dtype),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        weight_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                               weight_loader=weight_loader)
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)
        # INPUT SCALE
        scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                        weight_loader=weight_loader)

        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:
    weight = layer.weight
    max_w_scale = layer.weight_scale.max()
    if not (layer.weight_scale == layer.weight_scale[0]).all():
        max_w_scale, weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths)
    layer.weight = Parameter(weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
    layer.input_scale = Parameter(layer.input_scale.max(),
                                  requires_grad=False)

ModelOptFp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale.

Args:
    quant_config: The ModelOpt quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported)
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             2 * intermediate_size_per_partition,
                             hidden_size,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             hidden_size,
                             intermediate_size_per_partition,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data,
                                     requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize)

        # Handle scale parameters
        if hasattr(layer,
                   "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][start:start +
                                                        intermediate_size, :],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale

                        (
                            layer.w13_weight[expert_id][start:start +
                                                        intermediate_size, :],
                            _,
                        ) = scaled_fp8_quant(dq_weight,
                                             max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales,
                                                   requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                                   requires_grad=False)

        if hasattr(layer,
                   "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                              requires_grad=False)
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer,
                   "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
                                              requires_grad=False)
        if hasattr(layer,
                   "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                             requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

cutlass_fp8_supported instance-attribute

cutlass_fp8_supported = cutlass_fp8_supported()

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptFp8Config):
    self.quant_config = quant_config
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        cutlass_fp8_supported)
    self.cutlass_fp8_supported = cutlass_fp8_supported()

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

    # Expert selection
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_experts)
    return fused_experts(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        use_fp8_w8a8=True,
        per_channel_quant=False,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):

    # Use FP8 dtype if checkpoint is serialized
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_fp8_serialized else
                    params_dtype)
    weight_loader = extra_weight_attrs.get("weight_loader")

    w13_weight = ModelWeightParameter(
        data=torch.empty(num_experts,
                         2 * intermediate_size_per_partition,
                         hidden_size,
                         dtype=weight_dtype),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight", w13_weight)

    w2_weight = ModelWeightParameter(
        data=torch.empty(num_experts,
                         hidden_size,
                         intermediate_size_per_partition,
                         dtype=weight_dtype),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_weight", w2_weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, 2),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Set weight loader attributes for scales
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Process FP8 MoE weights after loading from serialized checkpoint. Only supports pre-quantized checkpoints with FP8 weights and scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Process FP8 MoE weights after loading from serialized checkpoint.
    Only supports pre-quantized checkpoints with FP8 weights and scales.
    """

    layer.w13_weight = Parameter(layer.w13_weight.data,
                                 requires_grad=False)
    layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    from vllm._custom_ops import scaled_fp8_quant
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        per_tensor_dequantize)

    # Handle scale parameters
    if hasattr(layer,
               "w13_weight_scale") and layer.w13_weight_scale is not None:
        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max of the w1 and w3 scales
        # then dequant and requant each expert.
        if layer.w13_weight_scale.dim() == 2:

            # Get the maximum scale across w1 and w3 for each expert
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values

            # Requantize each expert's weights using the combined scale
            # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
            # where the first intermediate_size rows are w1, the next are w3
            intermediate_size = layer.w13_weight.shape[1] // 2
            for expert_id in range(layer.w13_weight.shape[0]):
                start = 0
                for shard_id in range(2):  # w1 and w3
                    # Dequantize using the original scale for this shard
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    intermediate_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    # Requantize using the combined max scale

                    (
                        layer.w13_weight[expert_id][start:start +
                                                    intermediate_size, :],
                        _,
                    ) = scaled_fp8_quant(dq_weight,
                                         max_w13_scales[expert_id])

                    start += intermediate_size

            # Update the scale parameter to be per-expert
            layer.w13_weight_scale = Parameter(max_w13_scales,
                                               requires_grad=False)
        else:
            layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                               requires_grad=False)

    if hasattr(layer,
               "w2_weight_scale") and layer.w2_weight_scale is not None:
        layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                          requires_grad=False)
    # Input scales must be equal for each expert in fp8 MoE layers.
    if hasattr(layer,
               "w13_input_scale") and layer.w13_input_scale is not None:
        layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
                                          requires_grad=False)
    if hasattr(layer,
               "w2_input_scale") and layer.w2_input_scale is not None:
        layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                         requires_grad=False)
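
A hedged numeric sketch of the w13 scale folding above, using plain tensor ops rather than vLLM's scaled_fp8_quant / per_tensor_dequantize helpers: each expert's w1 and w3 shards are dequantized with their own per-tensor scale and requantized with the per-expert maximum so that a single scale serves the fused kernel.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quant(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# One expert's two shards, each with its own per-tensor scale.
w1_scale, w3_scale = torch.tensor(0.01), torch.tensor(0.03)
w1_q = quant(torch.randn(4, 8), w1_scale)
w3_q = quant(torch.randn(4, 8), w3_scale)

# Fold both shards onto the shared maximum scale: dequantize with the
# original shard scale, requantize with max(w1_scale, w3_scale).
max_scale = torch.maximum(w1_scale, w3_scale)
w1_q = quant(w1_q.float() * w1_scale, max_scale)
w3_q = quant(w3_q.float() * w3_scale, max_scale)
# Both shards now share max_scale, mirroring max_w13_scales above.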

ModelOptNvFp4Config

Bases: QuantizationConfig

Config class for ModelOpt FP4.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future.")

            self.group_size = group_size
            self.kv_cache_quant_algo = kv_cache_quant_algo
            self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        if quant_method not in QUANT_ALGOS:
            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                             " quantizations in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration.")
        is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
        if ("group_size" and "kv_cache_quant_algo"
                and "exclude_modules") not in quant_config:
            raise ValueError("NVFP4 quantization requires group size and "
                             "kv_cache_quant_algo specified in "
                             "hf_quant_config.json")
        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
        group_size = quant_config["group_size"]
        exclude_modules = quant_config["exclude_modules"]
        return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                   exclude_modules, group_size)

    def is_layer_excluded(self, prefix: str, exclude_modules: list):
        import regex as re
        for pattern in exclude_modules:
            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if (is_layer_skipped(prefix, self.exclude_modules)
                    or self.is_layer_excluded(prefix, self.exclude_modules)):
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptNvFp4FusedMoE(self)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

group_size instance-attribute

group_size = group_size

is_checkpoint_nvfp4_serialized instance-attribute

is_checkpoint_nvfp4_serialized = (
    is_checkpoint_nvfp4_serialized
)

kv_cache_quant_algo instance-attribute

kv_cache_quant_algo = kv_cache_quant_algo

__init__

__init__(
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: str,
    exclude_modules: list[str],
    group_size: int = 16,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: str,
    exclude_modules: list[str],
    group_size: int = 16,
) -> None:
    super().__init__()
    self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
    if is_checkpoint_nvfp4_serialized:
        logger.warning(
            "Detected ModelOpt NVFP4 checkpoint. Please note that"
            " the format is experimental and could change in future.")

        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptNvFp4Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    quant_config = cls.get_from_keys(config, ["quantization"])
    quant_method = quant_config["quant_algo"]
    if quant_method not in QUANT_ALGOS:
        raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                         " quantizations in vLLM. Please check the "
                         "`hf_quant_config.json` file for your model's "
                         "quant configuration.")
    is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
    if ("group_size" and "kv_cache_quant_algo"
            and "exclude_modules") not in quant_config:
        raise ValueError("NVFP4 quantization requires group size and "
                         "kv_cache_quant_algo specified in "
                         "hf_quant_config.json")
    kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
    group_size = quant_config["group_size"]
    exclude_modules = quant_config["exclude_modules"]
    return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
               exclude_modules, group_size)
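
For reference, a hedged sketch of the dict that from_config expects for NVFP4 checkpoints (the parsed hf_quant_config.json). The field values are illustrative assumptions.

from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

example_hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    },
}

config = ModelOptNvFp4Config.from_config(example_hf_quant_config)
assert config.is_checkpoint_nvfp4_serialized
assert config.group_size == 16
assert config.kv_cache_quant_algo == "FP8"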

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    return 80

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "modelopt_fp4"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import
    if isinstance(layer, LinearBase):
        if (is_layer_skipped(prefix, self.exclude_modules)
                or self.is_layer_excluded(prefix, self.exclude_modules)):
            return UnquantizedLinearMethod()
        return ModelOptNvFp4LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return ModelOptNvFp4FusedMoE(self)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

is_layer_excluded

is_layer_excluded(prefix: str, exclude_modules: list)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str, exclude_modules: list):
    import regex as re
    for pattern in exclude_modules:
        regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
        if re.fullmatch(regex_str, prefix):
            return True
    return False
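
A hedged sketch of the wildcard matching above; the module names and patterns are illustrative.

from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

config = ModelOptNvFp4Config(
    is_checkpoint_nvfp4_serialized=True,
    kv_cache_quant_algo="FP8",
    exclude_modules=["lm_head", "*.mlp.gate"],
)

# '.' is escaped, '*' becomes '.*', and the resulting pattern must match
# the full module prefix (re.fullmatch).
assert config.is_layer_excluded("lm_head", config.exclude_modules)
assert config.is_layer_excluded("model.layers.3.mlp.gate", config.exclude_modules)
assert not config.is_layer_excluded("model.layers.3.mlp.gate_proj",
                                    config.exclude_modules)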

ModelOptNvFp4FusedMoE

Bases: FusedMoEMethodBase

MoE method for FP4 quantization.

Args:
    quant_config: NVFP4 quant config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        self.use_marlin = False
        self.allow_flashinfer_cutlass = False

        if envs.VLLM_USE_FLASHINFER_MOE:
            if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
               and current_platform.is_device_capability(100):
                logger.info_once(
                    "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
                self.allow_flashinfer_cutlass = True
            else:
                logger.warning_once(
                    "Flashinfer CUTLASS Fused MoE not supported "
                    "or found on the current platform.")

        if not self.cutlass_nvfp4_supported:
            if is_fp4_marlin_supported():
                self.use_marlin = True
            else:
                raise ValueError("Current platform does not support NVFP4"
                                 " quantization. Please use Blackwell and"
                                 " above.")

        self.fused_experts = None  # type: ignore

    def maybe_swap_experts_impl(
        self,
        moe_parallel_config: FusedMoEParallelConfig,
    ):
        if not self.allow_flashinfer_cutlass:
            return

        logger.debug_once("FlashInferExperts")
        # default to TP/EP case only

        experts_kwargs: dict[str, Any] = {
            "use_nvfp4_w4a4": True,
            "use_dp": moe_parallel_config.dp_size > 1,
            "ep_rank": moe_parallel_config.ep_rank,
            "ep_size": moe_parallel_config.ep_size,
            "tp_rank": moe_parallel_config.tp_rank,
            "tp_size": moe_parallel_config.tp_size,
        }

        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            FlashInferExperts)
        experts = FlashInferExperts(**experts_kwargs)
        self.fused_experts = mk.FusedMoEModularKernel(
            FlashInferCutlassMoEPrepareAndFinalize(
                quant_dtype=torch.uint8,
                # torch.uint8: two e2m1 (fp4) values packed per byte (kernel requirement)
            ),
            experts,
        )

    # This method updates self.fused_experts.
    # select_gemm_impl is only called when prepare_finalize is not None,
    # so with native cutlass fp4 the default fused_experts from fused_moe.py
    # is used; when it is not called (the TP case) both kernels remain usable.
    def select_gemm_impl(self, prepare_finalize,
                         moe) -> mk.FusedMoEPermuteExpertsUnpermute:

        assert moe is not None
        assert prepare_finalize is not None
        experts = None
        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None
        if self.allow_flashinfer_cutlass:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                FlashInferExperts)
            logger.debug_once("Using FlashInferExperts")
            experts = FlashInferExperts(
                use_nvfp4_w4a4=True,
                use_dp=moe.moe_parallel_config.dp_size > 1,
                ep_rank=moe.moe_parallel_config.ep_rank,
                ep_size=moe.moe_parallel_config.ep_size,
                tp_rank=moe.moe_parallel_config.tp_rank,
                tp_size=moe.moe_parallel_config.tp_size,
            )
        else:
            assert moe.dp_size > 1
            logger.debug_once("Using CutlassExpertsFp4")
            # Currently CutlassExpertsFp4 doesn't support DP
            raise ValueError(
                "CutlassExpertsFp4 doesn't support DP. "
                "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
                " backend instead.")

        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition //
                self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        w13_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, 2, dtype=torch.float32),
                                                  weight_loader=weight_loader)
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def swizzle_blockscale(self, scale: torch.Tensor):
        assert (scale.dtype == torch.float8_e4m3fn)
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        return (swizzled_scale.reshape(M, K)
                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # GEMM 1
        # The FlashInfer Cutlass fused MoE kernel expects the combined weights
        # to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer_cutlass:
            dim = -2
            size = gemm1_weight.size(dim)
            assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
            half = size // 2

            # Reorder weight
            w1, w3 = gemm1_weight.split(half, dim=dim)
            gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()

            # Reorder scale
            s1, s3 = gemm1_weight_scale.split(half, dim=dim)
            gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                           requires_grad=False)

        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                              layer.w13_weight_scale_2[:, 1]):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected.")

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                             requires_grad=False)

        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
            torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False)

        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w13_blockscale_swizzled = self.swizzle_blockscale(
            layer.w13_weight_scale)

        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                  requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False)

        # GEMM 2
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)

        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                 requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
            del layer.w13_blockscale_swizzled
            del layer.w2_blockscale_swizzled

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ):
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
        assert activation == "silu", "Only SiLU activation is supported."

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        if self.use_marlin:
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        if self.fused_experts is None:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
                cutlass_moe_fp4)
            out = cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                w1_blockscale=layer.w13_blockscale_swizzled,
                w2_blockscale=layer.w2_blockscale_swizzled,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
                device=x.device,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            # TP or DP case
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                is_valid_flashinfer_cutlass_fused_moe)
            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")

            a1_gscale = torch.min(layer.w13_input_scale_quant)
            a2_gscale = torch.min(layer.w2_input_scale_quant)
            extra_expert_args = {
                'g1_alphas': layer.g1_alphas,
                'g2_alphas': layer.g2_alphas,
                'out_dtype': x.dtype,
                # Avoid confusion with a1_scale and a2_scale
                # which are batch-size related.
                'a1_gscale': a1_gscale,
                'a2_gscale': a2_gscale,
            }
            extra_prepare_args = {
                'use_dp': layer.dp_size > 1,
                'local_tokens': x.shape[0],
                'a1_gscale': a1_gscale,
            }
            extra_finalize_args = {
                'use_dp': layer.dp_size > 1,
                'local_tokens': x.shape[0],
            }

            out = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_blockscale_swizzled,
                w2_scale=layer.w2_blockscale_swizzled,
                apply_router_weight_on_input=apply_router_weight_on_input,
                extra_expert_args=extra_expert_args,
                extra_prepare_args=extra_prepare_args,
                extra_finalize_args=extra_finalize_args,
            )
        return out

allow_flashinfer_cutlass instance-attribute

allow_flashinfer_cutlass = False

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_fp4_supported()

fused_experts instance-attribute

fused_experts = None

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = False

__init__

__init__(quant_config: ModelOptNvFp4Config)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptNvFp4Config):
    self.quant_config = quant_config
    self.cutlass_nvfp4_supported = cutlass_fp4_supported()
    self.use_marlin = False
    self.allow_flashinfer_cutlass = False

    if envs.VLLM_USE_FLASHINFER_MOE:
        if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
           and current_platform.is_device_capability(100):
            logger.info_once(
                "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
            self.allow_flashinfer_cutlass = True
        else:
            logger.warning_once(
                "Flashinfer CUTLASS Fused MoE not supported "
                "or found on the current platform.")

    if not self.cutlass_nvfp4_supported:
        if is_fp4_marlin_supported():
            self.use_marlin = True
        else:
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and"
                             " above.")

    self.fused_experts = None  # type: ignore
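
The constructor above only takes the FlashInfer CUTLASS path when the VLLM_USE_FLASHINFER_MOE environment flag is set, CUTLASS FP4 is supported, and the device reports capability 100 (Blackwell); otherwise it falls back to CUTLASS FP4 or FP4 Marlin. A minimal, hypothetical usage sketch for opting in; the model path is a placeholder and the flag value of "1" is assumed to be how the boolean env var is parsed:

# Hypothetical sketch: opting in to the FlashInfer CUTLASS fused-MoE path for
# an NVFP4 MoE checkpoint. "1" is assumed to enable the boolean env flag, and
# the model path is a placeholder.
import os

os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

from vllm import LLM

llm = LLM(model="path/to/nvfp4-moe-checkpoint")  # quantization is typically auto-detected from the checkpoint config

If the flag is unset or the platform check fails, the constructor silently falls back to the other backends rather than raising.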

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
    assert activation == "silu", "Only SiLU activation is supported."

    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias)

    if self.use_marlin:
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=layer.w13_weight_scale_2,
            global_scale2=layer.w2_weight_scale_2,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    if self.fused_experts is None:
        # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
        # only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_fp4)
        out = cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            w1_blockscale=layer.w13_blockscale_swizzled,
            w2_blockscale=layer.w2_blockscale_swizzled,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=x.shape[0],
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
            device=x.device,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)
    else:
        # TP or DP case
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            is_valid_flashinfer_cutlass_fused_moe)
        assert is_valid_flashinfer_cutlass_fused_moe(
            x, layer.w13_weight, layer.w2_weight), (
                "Flashinfer CUTLASS Fused MoE not applicable!")

        a1_gscale = torch.min(layer.w13_input_scale_quant)
        a2_gscale = torch.min(layer.w2_input_scale_quant)
        extra_expert_args = {
            'g1_alphas': layer.g1_alphas,
            'g2_alphas': layer.g2_alphas,
            'out_dtype': x.dtype,
            # Avoid confusion with a1_scale and a2_scale
            # which are batch-size related.
            'a1_gscale': a1_gscale,
            'a2_gscale': a2_gscale,
        }
        extra_prepare_args = {
            'use_dp': layer.dp_size > 1,
            'local_tokens': x.shape[0],
            'a1_gscale': a1_gscale,
        }
        extra_finalize_args = {
            'use_dp': layer.dp_size > 1,
            'local_tokens': x.shape[0],
        }

        out = self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_blockscale_swizzled,
            w2_scale=layer.w2_blockscale_swizzled,
            apply_router_weight_on_input=apply_router_weight_on_input,
            extra_expert_args=extra_expert_args,
            extra_prepare_args=extra_prepare_args,
            extra_finalize_args=extra_finalize_args,
        )
    return out
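
One easy-to-miss detail in the cutlass_moe_fp4 call above is the dimension bookkeeping: two FP4 values are packed per uint8 byte, so the unpacked intermediate size n is recovered by doubling the packed w2 weight's last dimension, while m, k, and e come straight from the activations and the expert count. A small worked sketch with made-up sizes:

import torch

# Made-up sizes for illustration: 8 experts, hidden size 4096,
# per-partition intermediate size 14336.
num_experts, hidden_size, intermediate = 8, 4096, 14336

# Packed FP4 weights, shaped as registered in create_weights below
# (two fp4 values per uint8 byte along the input dimension).
w13_weight = torch.empty(num_experts, 2 * intermediate, hidden_size // 2,
                         dtype=torch.uint8)
w2_weight = torch.empty(num_experts, hidden_size, intermediate // 2,
                        dtype=torch.uint8)

x = torch.empty(16, hidden_size)  # a batch of 16 tokens

m = x.shape[0]              # tokens                 -> 16
n = w2_weight.shape[2] * 2  # unpacked intermediate  -> 14336
k = x.shape[1]              # hidden size            -> 4096
e = w13_weight.shape[0]     # number of experts      -> 8
print(m, n, k, e)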

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")

    layer.num_experts = num_experts
    layer.params_dtype = params_dtype
    layer.quant_config = self.quant_config
    weight_dtype = torch.uint8
    weight_scale_dtype = torch.float8_e4m3fn
    weight_loader = extra_weight_attrs.get("weight_loader")
    # GEMM 1
    w13_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            # 2 fp4 items are packed in the input dimension
            hidden_size // 2,
            dtype=weight_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w13_weight", w13_weight)

    # GEMM 2
    w2_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            # 2 fp4 items are packed in the input dimension
            intermediate_size_per_partition // 2,
            dtype=weight_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w2_weight", w2_weight)

    w13_weight_scale = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            # one block scale per group_size elements in the input dim
            hidden_size // self.quant_config.group_size,
            dtype=weight_scale_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w13_weight_scale", w13_weight_scale)

    w2_weight_scale = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            # one block scale per group_size elements in the input dim
            intermediate_size_per_partition //
            self.quant_config.group_size,
            dtype=weight_scale_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w2_weight_scale", w2_weight_scale)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

    w13_weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(num_experts, 2, dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

    w2_weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(num_experts, dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

    w13_input_scale = PerTensorScaleParameter(data=torch.empty(
        num_experts, 2, dtype=torch.float32),
                                              weight_loader=weight_loader)
    layer.register_parameter("w13_input_scale", w13_input_scale)

    w2_input_scale = PerTensorScaleParameter(data=torch.empty(
        num_experts, dtype=torch.float32),
                                             weight_loader=weight_loader)
    layer.register_parameter("w2_input_scale", w2_input_scale)

maybe_swap_experts_impl

maybe_swap_experts_impl(
    moe_parallel_config: FusedMoEParallelConfig,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def maybe_swap_experts_impl(
    self,
    moe_parallel_config: FusedMoEParallelConfig,
):
    if not self.allow_flashinfer_cutlass:
        return

    logger.debug_once("FlashInferExperts")
    # default to TP/EP case only

    experts_kwargs: dict[str, Any] = {
        "use_nvfp4_w4a4": True,
        "use_dp": moe_parallel_config.dp_size > 1,
        "ep_rank": moe_parallel_config.ep_rank,
        "ep_size": moe_parallel_config.ep_size,
        "tp_rank": moe_parallel_config.tp_rank,
        "tp_size": moe_parallel_config.tp_size,
    }

    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
        FlashInferExperts)
    experts = FlashInferExperts(**experts_kwargs)
    self.fused_experts = mk.FusedMoEModularKernel(
        FlashInferCutlassMoEPrepareAndFinalize(
            quant_dtype=torch.uint8,
            # uint8: two e2m1 (fp4) values packed per byte, a kernel requirement
        ),
        experts,
    )

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # GEMM 1
    # The FlashInfer Cutlass fused MoE kernel expects the combined weights
    # to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
    gemm1_weight = layer.w13_weight.data
    gemm1_weight_scale = layer.w13_weight_scale.data

    if self.allow_flashinfer_cutlass:
        dim = -2
        size = gemm1_weight.size(dim)
        assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
        half = size // 2

        # Reorder weight
        w1, w3 = gemm1_weight.split(half, dim=dim)
        gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()

        # Reorder scale
        s1, s3 = gemm1_weight_scale.split(half, dim=dim)
        gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()

    layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
    layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                       requires_grad=False)

    if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                          layer.w13_weight_scale_2[:, 1]):
        logger.warning_once(
            "w1_weight_scale_2 must match w3_weight_scale_2. "
            "Accuracy may be affected.")

    w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
    layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                         requires_grad=False)

    w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
        torch.float32)
    layer.g1_alphas = Parameter(
        (w13_input_scale * w13_weight_scale_2).to(torch.float32),
        requires_grad=False)

    assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
        "Expected weight_scale.dim(1) to be divisible by 16")
    assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
        "Weight Blockscale must be represented as FP8-E4M3")
    w13_blockscale_swizzled = self.swizzle_blockscale(
        layer.w13_weight_scale)

    layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                              requires_grad=False)

    # This is for quantization, so we need to invert it.
    layer.w13_input_scale_quant = Parameter(
        (1 / w13_input_scale).to(torch.float32), requires_grad=False)

    # GEMM 2
    layer.g2_alphas = Parameter(
        (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
        requires_grad=False)

    # This is for quantization, so we need to invert it.
    layer.w2_input_scale_quant = Parameter(
        (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

    assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
        "Expected weight_scale.dim(1) to be divisible by 16")
    assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
        "Weight Blockscale must be represented as FP8-E4M3")
    w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)

    layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                             requires_grad=False)
    layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    if self.use_marlin:
        prepare_moe_fp4_layer_for_marlin(layer)
        del layer.g1_alphas
        del layer.g2_alphas
        del layer.w13_input_scale_quant
        del layer.w2_input_scale_quant
        del layer.w13_blockscale_swizzled
        del layer.w2_blockscale_swizzled

select_gemm_impl

select_gemm_impl(
    prepare_finalize, moe
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/modelopt.py
def select_gemm_impl(self, prepare_finalize,
                     moe) -> mk.FusedMoEPermuteExpertsUnpermute:

    assert moe is not None
    assert prepare_finalize is not None
    experts = None
    all2all_manager = get_ep_group().device_communicator.all2all_manager
    assert all2all_manager is not None
    if self.allow_flashinfer_cutlass:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            FlashInferExperts)
        logger.debug_once("Using FlashInferExperts")
        experts = FlashInferExperts(
            use_nvfp4_w4a4=True,
            use_dp=moe.moe_parallel_config.dp_size > 1,
            ep_rank=moe.moe_parallel_config.ep_rank,
            ep_size=moe.moe_parallel_config.ep_size,
            tp_rank=moe.moe_parallel_config.tp_rank,
            tp_size=moe.moe_parallel_config.tp_size,
        )
    else:
        assert moe.dp_size > 1
        logger.debug_once("Using CutlassExpertsFp4")
        # Currently CutlassExpertsFp4 doesn't support DP
        raise ValueError(
            "CutlassExpertsFp4 doesn't support DP. "
            "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
            " backend instead.")

    return experts

swizzle_blockscale

swizzle_blockscale(scale: tensor)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def swizzle_blockscale(self, scale: torch.tensor):
    assert (scale.dtype == torch.float8_e4m3fn)
    # Pad and blockwise interleave weight_scale
    scale_ndim = scale.ndim
    if scale.ndim == 2:
        scale = scale.unsqueeze(0)
    assert scale.ndim == 3
    B, M, K = scale.shape
    round_up_multiple = lambda x, m: (x + m - 1) // m * m
    M_padded = round_up_multiple(M, 128)
    K_padded = round_up_multiple(K, 4)
    padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
    padded_scale[:B, :M, :K] = scale
    batches, rows, cols = padded_scale.shape
    assert rows % 128 == 0
    assert cols % 4 == 0
    padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                        cols // 4, 4)
    swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
    swizzled_scale = swizzled_scale.contiguous().cuda()
    return (swizzled_scale.reshape(M, K)
            if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
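
swizzle_blockscale pads the block scales to 128-row by 4-column tiles, interleaves them into the layout the CUTLASS FP4 kernels expect, and reshapes back to the original logical shape. Note that the final reshape back to the unpadded (M, K) only succeeds when M and K are already multiples of 128 and 4, which is the case for typical model dimensions. The following is an illustrative, CPU-only re-run of the same index permutation on a small float32 tensor (the FP8 dtype check and the .cuda() transfer are dropped):

import torch

def swizzle_like(scale: torch.Tensor) -> torch.Tensor:
    # Illustrative copy of the permutation above, run on float32 and CPU only.
    scale_ndim = scale.ndim
    if scale.ndim == 2:
        scale = scale.unsqueeze(0)
    B, M, K = scale.shape
    padded = torch.zeros((B, (M + 127) // 128 * 128, (K + 3) // 4 * 4),
                         dtype=scale.dtype)
    padded[:, :M, :K] = scale
    b, rows, cols = padded.shape
    tiled = padded.reshape(b, rows // 128, 4, 32, cols // 4, 4)
    swizzled = tiled.permute((0, 1, 4, 3, 2, 5)).contiguous()
    return (swizzled.reshape(M, K)
            if scale_ndim == 2 else swizzled.reshape(B, M, K))

scale = torch.arange(128 * 8, dtype=torch.float32).reshape(128, 8)
out = swizzle_like(scale)
print(out.shape)  # torch.Size([128, 8]) -- same logical shape, reordered contents
print(out[0])     # [0, 1, 2, 3, 256, 257, 258, 259]: columns from rows 0 and 32 interleaved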

uses_weight_scale_2_pattern

uses_weight_scale_2_pattern() -> bool

FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def uses_weight_scale_2_pattern(self) -> bool:
    """
    FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
    """
    return True
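
Returning True here signals that this method's per-tensor scales live under the weight_scale_2 suffix, while weight_scale carries the per-block FP8 scales. Below is a minimal, hypothetical sketch of how checkpoint keys could be bucketed based on such a flag; classify_scale_key and the example key names are illustrative, not vLLM internals:

# Hypothetical helper: bucket checkpoint keys into per-block vs per-tensor
# scales depending on whether the quant method reports the 'weight_scale_2'
# pattern. Not a vLLM API; key names are made up for illustration.
def classify_scale_key(key: str, uses_weight_scale_2: bool) -> str:
    if uses_weight_scale_2 and key.endswith("weight_scale_2"):
        return "per_tensor_scale"
    if key.endswith("weight_scale"):
        return "per_block_scale" if uses_weight_scale_2 else "per_tensor_scale"
    return "other"

assert classify_scale_key("layers.0.mlp.gate_up_proj.weight_scale_2", True) == "per_tensor_scale"
assert classify_scale_key("layers.0.mlp.gate_up_proj.weight_scale", True) == "per_block_scale"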

ModelOptNvFp4LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure:

input_scale: torch.float32, scalar
weight: NVFP4 (represented as bytes), shape [1, X, Y/2]
weight_scale: FP8-E4M3, shape [X, Y] (per-block scale)
weight_scale_2: torch.float32, scalar

Args: quant_config: The ModelOpt quantization config.
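
As a concrete reference, here is a minimal sketch that allocates placeholder tensors with the shapes and dtypes that create_weights below registers for a hypothetical 4096x4096 partition with group size 16; the sizes are illustrative only:

import torch

# Hypothetical NVFP4 linear layer: 4096 output features, 4096 input features,
# block (group) size 16. Shapes mirror the parameters registered below.
out_features, in_features, group_size = 4096, 4096, 16

weight = torch.empty(out_features, in_features // 2, dtype=torch.uint8)  # two fp4 values per byte
weight_scale = torch.empty(out_features, in_features // group_size,
                           dtype=torch.float8_e4m3fn)                    # per-block scale
weight_scale_2 = torch.empty(1, dtype=torch.float32)                     # global weight scale (scalar after loading)
input_scale = torch.empty(1, dtype=torch.float32)                        # global activation scale (scalar after loading)

print(weight.shape, weight_scale.shape)  # [4096, 2048] and [4096, 256]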

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        self.use_marlin = False

        if not self.cutlass_nvfp4_supported:
            if is_fp4_marlin_supported():
                self.use_marlin = True
            else:
                raise ValueError("Current platform does not support NVFP4"
                                 " quantization. Please use Blackwell and"
                                 " above.")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if (input_size_per_partition % 16 != 0):
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")
        # Block scales are stored as FP8-E4M3 for serialized NVFP4 checkpoints.
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=weight_dtype,
        ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

    def swizzle_blockscale(self, scale: torch.tensor):
        assert (scale.dtype == torch.float8_e4m3fn)
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        return (swizzled_scale.reshape(M, K)
                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

    def process_weights_after_loading(self, layer: Module) -> None:

        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert (layer.weight_scale.shape[1] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")
        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        if self.use_marlin:
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
            del layer.weight_scale_swizzled

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.use_marlin:
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        s_quant = 1 / layer.input_scale
        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert (x_fp4.dtype == torch.uint8)
        assert (layer.weight.dtype == torch.uint8)
        assert (x_blockscale.dtype == torch.float8_e4m3fn)
        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled, layer.alpha,
                                    output_dtype)
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_fp4_supported()

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = False

__init__

__init__(quant_config: ModelOptNvFp4Config)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptNvFp4Config):
    self.quant_config = quant_config
    self.cutlass_nvfp4_supported = cutlass_fp4_supported()
    self.use_marlin = False

    if not self.cutlass_nvfp4_supported:
        if is_fp4_marlin_supported():
            self.use_marlin = True
        else:
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and"
                             " above.")

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.use_marlin:
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)

    output_dtype = x.dtype
    output_shape = [x.shape[0], layer.weight.shape[0]]

    # quantize BF16 or FP16 to (FP4 and interleaved block scale)
    s_quant = 1 / layer.input_scale
    x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

    # validate dtypes of quantized input, input block scale,
    # weight and weight_blockscale
    assert (x_fp4.dtype == torch.uint8)
    assert (layer.weight.dtype == torch.uint8)
    assert (x_blockscale.dtype == torch.float8_e4m3fn)
    assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
    assert (layer.alpha.dtype == torch.float32)

    out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                layer.weight_scale_swizzled, layer.alpha,
                                output_dtype)
    if bias is not None:
        out = out + bias
    return out.view(*output_shape)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition

    if (input_size_per_partition % 16 != 0):
        raise ValueError("Unsupported model when in features size is "
                         "not multiple of 16")
    # Block scales are stored as FP8-E4M3 for serialized NVFP4 checkpoints.
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_nvfp4_serialized
                    else params_dtype)
    # Weight
    weight = ModelWeightParameter(
        data=torch.empty(
            # 2 fp4 items are packed in the input dimension
            layer.output_size_per_partition,
            layer.input_size_per_partition // 2,
            dtype=torch.uint8),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    # Input Weight Scale
    input_scale = PerTensorScaleParameter(data=torch.empty(
        len(output_partition_sizes), dtype=torch.float32),
                                          weight_loader=weight_loader)
    layer.register_parameter("input_scale", input_scale)

    # Global Weight Scale
    weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
        len(output_partition_sizes), dtype=torch.float32),
                                             weight_loader=weight_loader)
    layer.register_parameter("weight_scale_2", weight_scale_2)

    # Per Block Weight Scale
    weight_scale = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition // self.quant_config.group_size,
        dtype=weight_dtype,
    ),
                                        input_dim=1,
                                        output_dim=0,
                                        weight_loader=weight_loader)

    layer.register_parameter("weight_scale", weight_scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:

    # global scales:
    input_scale_2 = layer.input_scale.max().to(torch.float32)
    layer.input_scale = Parameter(input_scale_2, requires_grad=False)

    weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
    layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

    layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                            requires_grad=False)

    # Swizzle the weight blockscale.
    # contracting dimension is input dimension
    # block_size = 16;
    assert (layer.weight_scale.shape[1] % 16 == 0), (
        "Expected weight_scale.dim(1) to be divisible by 16")
    assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
        "Weight Block scale must be represented as FP8-E4M3")
    swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

    layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                            requires_grad=False)
    layer.weight = Parameter(layer.weight.data, requires_grad=False)

    if self.use_marlin:
        prepare_fp4_layer_for_marlin(layer)
        del layer.alpha
        del layer.input_scale
        del layer.weight_scale_swizzled

swizzle_blockscale

swizzle_blockscale(scale: tensor)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def swizzle_blockscale(self, scale: torch.tensor):
    assert (scale.dtype == torch.float8_e4m3fn)
    # Pad and blockwise interleave weight_scale
    scale_ndim = scale.ndim
    if scale.ndim == 2:
        scale = scale.unsqueeze(0)
    assert scale.ndim == 3
    B, M, K = scale.shape
    round_up_multiple = lambda x, m: (x + m - 1) // m * m
    M_padded = round_up_multiple(M, 128)
    K_padded = round_up_multiple(K, 4)
    padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
    padded_scale[:B, :M, :K] = scale
    batches, rows, cols = padded_scale.shape
    assert rows % 128 == 0
    assert cols % 4 == 0
    padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                        cols // 4, 4)
    swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
    swizzled_scale = swizzled_scale.contiguous().cuda()
    return (swizzled_scale.reshape(M, K)
            if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

cutlass_fp4_supported

cutlass_fp4_supported() -> bool
Source code in vllm/model_executor/layers/quantization/modelopt.py
def cutlass_fp4_supported() -> bool:
    if not current_platform.is_cuda():
        return False
    capability_tuple = current_platform.get_device_capability()
    capability = -1 if capability_tuple is None else capability_tuple.to_int()
    return cutlass_scaled_mm_supports_fp4(capability)
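
A minimal sketch of using this helper to report, at startup, whether the CUTLASS NVFP4 path is available on the current machine; the import path is simply the module documented on this page:

# Minimal sketch: query FP4 kernel availability on the current GPU.
from vllm.model_executor.layers.quantization.modelopt import cutlass_fp4_supported

if cutlass_fp4_supported():
    print("CUTLASS NVFP4 kernels are supported on this GPU.")
else:
    print("No CUTLASS NVFP4 support here; vLLM would fall back to FP4 Marlin "
          "where available, or reject NVFP4 checkpoints otherwise.")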