乐于分享
好东西不私藏

TensorRT-LLM 0.5.0 源码之十三

TensorRT-LLM 0.5.0 源码之十三

QuantMode

from enum import IntFlag, auto 这行代码让你能够创建支持位运算(如按位或 |、按位与 &)的枚举类型,并能自动为成员分配值。这在管理一组可以组合使用的选项时(例如权限、状态标志)特别有用。

下面的表格总结了 IntFlag 和 auto 的核心特性和典型应用场景:

特性
说明
典型应用场景
位标志组合
使用 `
组合多个标志,用

&` 检查是否包含某标志。
自动赋值 auto()

 自动分配递增的2的幂次方值(1, 2, 4, 8…)。
简化枚举定义,避免手动计算
整数子类
成员也是整数,可在需要整数的场合直接使用。
与期望整数值的旧代码或API交互

基本定义与赋值

使用 IntFlag 和 auto() 可以非常清晰地定义一组标志。auto() 会自动分配 2 的幂次方值(1, 2, 4, 8…),这确保了每个标志对应一个独立的二进制位,是进行位运算的基础 。

from enum import IntFlag, autoclass Permissions(IntFlag):    READ = auto()    # 自动赋值为 1 (二进制 0b001)    WRITE = auto()   # 自动赋值为 2 (二进制 0b010)    EXECUTE = auto() # 自动赋值为 4 (二进制 0b100)

组合与检查标志

定义好标志后,你可以使用位运算符来组合和检查它们。

# 组合标志:使用按位或运算符 "|"user_permissions = Permissions.READ | Permissions.WRITEprint(user_permissions)  # 输出: Permissions.READ|WRITEprint(int(user_permissions))  # 输出: 3 (因为 1 | 2 = 3)# 检查标志:使用按位与运算符 "&" 或 "in" 关键字if user_permissions & Permissions.READ:  # 方式1:使用 "&"    print("具有读权限")if Permissions.READ in user_permissions:  # 方式2:使用 "in",更直观    print("具有读权限")# 检查是否同时拥有多个权限if (user_permissions & (Permissions.READ | Permissions.WRITE)) == (Permissions.READ | Permissions.WRITE):    print("同时具有读和写权限")

实用技巧与场景

  • • 自定义起始值:虽然 auto() 通常从1开始,但你可以通过定义 _generate_next_value_ 方法来自定义赋值逻辑。例如,让值从其他数字开始,或者使用成员名称的哈希值等 。
  • • 零值与无标志:可以显式定义一个值为 0 的成员,表示“没有任何标志”。在布尔上下文中,值为 0 的成员会被视为 False 。

    class Color(Flag):    BLACK = 0    RED = auto()    BLUE = auto()print(bool(Color.BLACK))  # 输出: False
  • • 网络协议:在定义 TCP 标志等场景中非常实用 。

    class TCPFlags(IntFlag):    SYN = auto()    ACK = auto()    FIN = auto()# 表示一个 SYN-ACK 包packet_flags = TCPFlags.SYN | TCPFlags.ACK

注意事项

  • • IntFlag 与 Flag 的区别:还有一个基本的 Flag 类。与 IntFlag 的主要区别在于,Flag 的成员不是整数子类,因此不能直接与整数进行比较或运算,这提供了更强的类型隔离。如果你不需要和整数混用,使用 Flag 可能更安全 。
  • • 确保值唯一:如果枚举值需要全局唯一,可以使用 @unique 装饰器,这样如果有重复值(包括别名)会在定义时抛出异常 。
from enum import IntFlag, autoclass QuantMode(IntFlag):    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/common/quantization.h    # The weights are quantized to 4 bits.    INT4_WEIGHTS = auto()    # The weights are quantized to 8 bits.    INT8_WEIGHTS = auto()    # The activations are quantized.    ACTIVATIONS = auto()    # The method uses one scaling factor per channel. It's pre-computed (static) from the weights.    PER_CHANNEL = auto()    # The method uses one scaling factor per token. It's computed on-the-fly.    PER_TOKEN = auto()    # The method uses one scaling factor per group. It's pre-computed (static) from the weights.    PER_GROUP = auto()    # The KV cache is quantized in INT8.    INT8_KV_CACHE = auto()    # The KV cache is quantized in FP8.    FP8_KV_CACHE = auto()    # FP8 QDQ    FP8_QDQ = auto()    # The smallest power-of-two that is not used by a flag. Do not call auto() after that line.    COUNT = auto()    # Bitmask to detect if weights, activations or both are quantized.    WEIGHTS_AND_ACTIVATIONS = INT4_WEIGHTS | INT8_WEIGHTS | ACTIVATIONS    # The mask of all valid flags.    VALID_FLAGS = COUNT - 1    # All the bits set? You can restrict the test to the bits indicated by "mask".    def _all(self, bits, mask=VALID_FLAGS):        return (self & mask) == bits    # Is one of the bits of the mask set?    def _any(self, bits):        return (self & bits) != 0    def is_int8_weight_only(self):        return self._all(self.INT8_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)    def is_int4_weight_only(self):        return self._all(self.INT4_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)    def is_weight_only(self):        return self.is_int4_weight_only() or self.is_int8_weight_only()    def is_int4_weight_only_per_group(self):        return self.is_int4_weight_only() and self._any(self.PER_GROUP)    def has_act_and_weight_quant(self):        return self._all(self.INT8_WEIGHTS | self.ACTIVATIONS,                         self.WEIGHTS_AND_ACTIVATIONS)    def has_per_token_dynamic_scaling(self):        return self._any(self.PER_TOKEN)    def has_act_static_scaling(self):        return not self.has_per_token_dynamic_scaling()    def has_per_channel_scaling(self):        return self._any(self.PER_CHANNEL)    def has_per_group_scaling(self):        return self._any(self.PER_GROUP)    def has_int8_kv_cache(self):        return self._any(self.INT8_KV_CACHE)    def has_fp8_kv_cache(self):        return self._any(self.FP8_KV_CACHE)    def has_kv_cache_quant(self):        return self.has_int8_kv_cache() or self.has_fp8_kv_cache()    def has_fp8_qdq(self):        return self._any(self.FP8_QDQ)    def has_any_quant(self):        return self._any(self.INT8_WEIGHTS | self.ACTIVATIONS                         | self.INT8_KV_CACHE | self.FP8_KV_CACHE                         | self.FP8_QDQ)    def set_int8_kv_cache(self):        return self | self.INT8_KV_CACHE    def set_fp8_kv_cache(self):        return self | self.FP8_KV_CACHE    def set_fp8_qdq(self):        return self | self.FP8_QDQ    @staticmethod    def from_description(quantize_weights=False,                         quantize_activations=False,                         per_token=False,                         per_channel=False,                         per_group=False,                         use_int4_weights=False,                         use_int8_kv_cache=False,                         use_fp8_kv_cache=False,                         use_fp8_qdq=False):        def raise_error():            raise ValueError(f"Unsupported combination of QuantMode args: "                             f"{quantize_weights=}, "                             f"{quantize_activations=}, "                             f"{per_token=}, "                             f"{per_channel=}, "                             f"{per_group=}, "                             f"{use_int4_weights=}"                             f"{use_int8_kv_cache=}"                             f"{use_fp8_kv_cache=}"                             f"{use_fp8_qdq=}")        # We must quantize weights when we quantize activations.        if quantize_activations and not quantize_weights:            raise_error()        # If we set per_token or per_channel, we must quantize both weights and activations.        if (per_token or per_channel) and not (quantize_weights                                               and quantize_activations):            raise_error()        mode = QuantMode(0)        # Do we quantize the weights - if so, do we use INT4 or INT8?        if quantize_weights and use_int4_weights:            mode = mode | QuantMode.INT4_WEIGHTS        elif quantize_weights:            mode = mode | QuantMode.INT8_WEIGHTS        # Do we quantize the activations?        if quantize_activations:            mode = mode | QuantMode.ACTIVATIONS        # Per-channel/per-token/per-group additional flags.        if per_channel:            mode = mode | QuantMode.PER_CHANNEL        if per_token:            mode = mode | QuantMode.PER_TOKEN        if per_group:            mode = mode | QuantMode.PER_GROUP        # Int8 KV cache        if use_int8_kv_cache:            mode = mode | QuantMode.INT8_KV_CACHE        # FP8 KV cache        if use_fp8_kv_cache:            mode = mode | QuantMode.FP8_KV_CACHE        if use_fp8_qdq:            mode = mode | QuantMode.FP8_QDQ        return mode    @staticmethod    def use_smooth_quant(per_token=False, per_channel=False):        return QuantMode.from_description(True, True, per_token, per_channel)    @staticmethod    def use_weight_only(use_int4_weights=False):        return QuantMode.from_description(quantize_weights=True,                                          quantize_activations=False,                                          per_token=False,                                          per_channel=False,                                          per_group=False,                                          use_int4_weights=use_int4_weights)

Quantize

class Quantize(Module):    """        Quantize Layer        For per-tensor mode, the scaling factor is a scalar.        For per-channel mode, the scaling factor is a vector.        """    def __init__(        self,        output_dtype: str = 'int8',        scaling_factor_dtype: str = 'float32',        in_channels: int = -1,        axis=-1,) -> None:        super().__init__()        self.scaling_factor = Parameter(shape=(in_channels, ) if axis != -1 else                                        (),                                        dtype=scaling_factor_dtype)        self.output_dtype = output_dtype        self.axis = axis    def forward(self, x):        return quantize(x, self.scaling_factor.value, self.output_dtype,                        self.axis)
def quantize(input: Tensor,             scale_factor: Tensor,             dtype: str,             axis: int = -1) -> Tensor:    layer = default_trtnet().add_quantize(input.trt_tensor,                                          scale_factor.trt_tensor,                                          str_dtype_to_trt(dtype))    layer.axis = axis    output = _create_tensor(layer.get_output(0), layer)    if not default_net().strongly_typed:        layer.get_output(0).dtype = str_dtype_to_trt(dtype)    return output

QuantizePerToken

class QuantizePerToken(Module):    """        Quantize Per Token and compute dynamic scales for SmoothQuant        """    def forward(self, x):        return quantize_per_token(x)
def quantize_per_token(x: Tensor) -> Tuple[Tensor]:    if not default_net().plugin_config.quantize_per_token_plugin:        if x.dtype != trt.float32:            x = cast(x, 'float32')        xmax = x.abs().max(-1, keepdim=True)        scale = xmax / 127.0        out = x * 127.0 / xmax        out = round(out)        out = clip(out, -128, 127)        quantized_out = cast(out, 'int8')        return quantized_out, scale    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        pfc = trt.PluginFieldCollection([])        quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin",                                                  pfc)        plug_inputs = [x.trt_tensor]        layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)        layer.get_output(0).set_dynamic_range(-127, 127)        quantized = _create_tensor(layer.get_output(0), layer)        quantized.trt_tensor.dtype = str_dtype_to_trt("int8")        scales = _create_tensor(layer.get_output(1), layer)        scales.trt_tensor.dtype = str_dtype_to_trt("float32")        return quantized, scales

Dequantize

class Dequantize(Module):    """        Dequantize Layer.        """    def __init__(self, axis: int = -1) -> None:        super().__init__()        self.scaling_factor = Parameter(shape=())        self.axis = axis    def forward(self, input):        return dequantize(input, self.scaling_factor.value, self.axis)
def dequantize(input: Tensor,               scale_factor: Tensor,               axis: int = -1,               output_type: Union[str, trt.DataType] = 'float16') -> Tensor:    if isinstance(output_type, str):        output_type = str_dtype_to_trt(output_type)    layer = default_trtnet().add_dequantize(input.trt_tensor,                                            scale_factor.trt_tensor,                                            output_type)    layer.axis = axis    if not default_net().strongly_typed:        layer.precision = input.dtype    output = _create_tensor(layer.get_output(0), layer)    return output

SmoothQuantColumnLinear

在这里插入图片描述
在这里插入图片描述
class SmoothQuantLinear(Module):    def __init__(self,                 in_features,                 out_features,                 bias=True,                 dtype=None,                 tp_group=None,                 tp_size=1,                 gather_output=True,                 quant_mode=QuantMode(0)):        super().__init__()        self.in_features = in_features        self.out_features = out_features // tp_size        if not quant_mode.has_act_and_weight_quant():            raise ValueError(                "SmoothQuant Linear has to have act+weight quantization mode set"            )        weights_dtype = dtype        # Dirty hack to make it work with SmoothQuant int8 weights        # reinterpreted as fp32 weights due to the int8 TRT plugin limitation.        if quant_mode.has_act_and_weight_quant():            assert self.in_features % 4 == 0            self.in_features = self.in_features // 4            weights_dtype = "float32"        self.weight = Parameter(shape=(self.out_features, self.in_features),                                dtype=weights_dtype)        if quant_mode.has_act_and_weight_quant():            scale_shape = (1, self.out_features                           ) if quant_mode.has_per_channel_scaling() else (1, 1)            self.per_channel_scale = Parameter(shape=scale_shape,                                               dtype="float32")        if quant_mode.has_act_static_scaling():            self.act_scale = Parameter(shape=(1, 1), dtype="float32")        self.tp_size = tp_size        self.tp_group = tp_group        self.gather_output = gather_output        self.quant_mode = quant_mode        if bias:            self.bias = Parameter(shape=(self.out_features, ), dtype=dtype)        else:            self.register_parameter('bias', None)    def forward(self, x):        if self.quant_mode.has_act_static_scaling():            per_token_scale = self.act_scale.value        else:            # If we are in SmoothQuant with dynamic activation scaling,            # input x has to be a tuple of int8 tensor and fp32 scaling factors            x, per_token_scale = x        x = smooth_quant_gemm(x, self.weight.value, per_token_scale,                              self.per_channel_scale.value,                              self.quant_mode.has_per_token_dynamic_scaling(),                              self.quant_mode.has_per_channel_scaling())        if self.bias is not None:            x = x + self.bias.value        if self.gather_output and self.tp_size > 1 and self.tp_group is not None:            # 1. [dim0, local_dim] -> [dim0 * tp_size, local_dim]            x = allgather(x, self.tp_group)            # 2. [dim0 * tp_size, local_dim] -> [dim0, local_dim * tp_size]            # 2.1 split            split_size = shape(x, dim=0) / self.tp_size            ndim = x.ndim()            starts = [constant(int32_array([0])) for _ in range(ndim)]            sizes = [shape(x, dim=d) for d in range(ndim)]            sizes[0] = split_size            sections = []            for i in range(self.tp_size):                starts[0] = split_size * i                sections.append(slice(x, concat(starts), concat(sizes)))            # 2.2 concat            x = concat(sections, dim=1)        return xSmoothQuantColumnLinear = SmoothQuantLinear
def smooth_quant_gemm(input: Tensor, weights: Tensor, scales_a: Tensor,                      scales_b: Tensor, per_token_scaling: bool,                      per_channel_scaling: bool) -> Tensor:    if not default_net().plugin_config.smooth_quant_gemm_plugin:        raise TypeError("Smooth Quant GEMM is only supported with plugin")    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'SmoothQuantGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        per_channel_scaling = 1 if per_channel_scaling else 0        per_channel_scaling = trt.PluginField(            "has_per_channel_scaling",            np.array(per_channel_scaling, dtype=np.int32),            trt.PluginFieldType.INT32)        per_token_scaling = 1 if per_token_scaling else 0        per_token_scaling = trt.PluginField(            "has_per_token_scaling", np.array(per_token_scaling,                                              dtype=np.int32),            trt.PluginFieldType.INT32)        p_dtype = default_net().plugin_config.smooth_quant_gemm_plugin        pf_type = trt.PluginField(            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),            trt.PluginFieldType.INT32)        pfc = trt.PluginFieldCollection(            [per_channel_scaling, per_token_scaling, pf_type])        gemm_plug = plg_creator.create_plugin("sq_gemm", pfc)        plug_inputs = [            input.trt_tensor, weights.trt_tensor, scales_a.trt_tensor,            scales_b.trt_tensor        ]        layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)        layer.get_input(0).set_dynamic_range(-127, 127)        return _create_tensor(layer.get_output(0), layer)

SmoothQuantRowLinear

class SmoothQuantRowLinear(Module):    def __init__(self,                 in_features,                 out_features,                 bias=True,                 dtype=None,                 tp_group=None,                 tp_size=1,                 quant_mode=QuantMode(0)):        super().__init__()        self.in_features = in_features // tp_size        self.out_features = out_features        if not quant_mode.has_act_and_weight_quant():            raise ValueError(                "SmoothQuant Linear has to have act+weight quantization mode set"            )        weights_dtype = dtype        # Dirty hack to make it work with SmoothQuant int8 weights        # reinterpreted as fp32 weights due to the int8 TRT plugin limitation.        if quant_mode.has_act_and_weight_quant():            assert self.in_features % 4 == 0            self.in_features = self.in_features // 4            weights_dtype = "float32"        self.weight = Parameter(shape=(self.out_features, self.in_features),                                dtype=weights_dtype)        self.smoother = Parameter(shape=(1, self.in_features * 4),                                  dtype="float32")        if quant_mode.has_act_and_weight_quant():            scale_shape = (1, self.out_features                           ) if quant_mode.has_per_channel_scaling() else (1, 1)            self.per_channel_scale = Parameter(shape=scale_shape,                                               dtype="float32")        if quant_mode.has_act_static_scaling():            self.act_scale = Parameter(shape=(1, 1), dtype="float32")        if bias:            self.bias = Parameter(shape=(self.out_features, ), dtype=dtype)        else:            self.register_parameter('bias', None)        self.tp_group = tp_group        self.tp_size = tp_size        self.quant_mode = quant_mode    def forward(self, x, workspace=None):        if self.quant_mode.has_act_static_scaling():            per_token_scale = self.act_scale.value        else:            x, per_token_scale = x        x = smooth_quant_gemm(x, self.weight.value, per_token_scale,                              self.per_channel_scale.value,                              self.quant_mode.has_per_token_dynamic_scaling(),                              self.quant_mode.has_per_channel_scaling())        if self.tp_size > 1 and self.tp_group is not None:            x = allreduce(x, self.tp_group, workspace)        if self.bias is not None:            x = x + self.bias.value        return x

SmoothQuantLayerNorm

class SmoothQuantLayerNorm(Module):    def __init__(self,                 normalized_shape,                 eps=1e-05,                 elementwise_affine=True,                 dtype=None,                 quant_mode=QuantMode(0)):        super().__init__()        if isinstance(normalized_shape, int):            normalized_shape = (normalized_shape, )        if not quant_mode.has_act_and_weight_quant():            raise ValueError(                "SmoothQuant layer norm has to have some quantization mode set")        self.normalized_shape = tuple(normalized_shape)        self.elementwise_affine = elementwise_affine        if self.elementwise_affine:            self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)            self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)        else:            self.register_parameter('weight', None)            self.register_parameter('bias', None)        self.eps = eps        self.quant_mode = quant_mode        if self.quant_mode.has_act_and_weight_quant():            self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)        else:            self.register_parameter('scale_to_int', None)    def forward(self, x):        weight = None if self.weight is None else self.weight.value        bias = None if self.bias is None else self.bias.value        scale = None if self.scale_to_int is None else self.scale_to_int.value        return smooth_quant_layer_norm(            x,            self.normalized_shape,            weight,            bias,            scale,            self.eps,            dynamic_act_scaling=self.quant_mode.has_per_token_dynamic_scaling())
def smooth_quant_layer_norm(input: Tensor,                            normalized_shape: Union[int, Tuple[int]],                            weight: Optional[Tensor] = None,                            bias: Optional[Tensor] = None,                            scale: Optional[Tensor] = None,                            eps: float = 1e-05,                            use_diff_of_squares: bool = True,                            dynamic_act_scaling: bool = False) -> Tensor:    if not default_net().plugin_config.layernorm_quantization_plugin:        raise TypeError("Smooth Quant Layer Norm is only supported with plugin")    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'LayernormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        eps = trt.PluginField("eps", np.array(eps, dtype=np.float32),                              trt.PluginFieldType.FLOAT32)        use_diff_of_squares = trt.PluginField(            "use_diff_of_squares",            np.array([int(use_diff_of_squares)], dtype=np.int32),            trt.PluginFieldType.INT32)        dyn_act_scaling = trt.PluginField(            "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),            trt.PluginFieldType.INT32)        p_dtype = default_net().plugin_config.layernorm_quantization_plugin        pf_type = trt.PluginField(            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),            trt.PluginFieldType.INT32)        pfc = trt.PluginFieldCollection(            [eps, use_diff_of_squares, dyn_act_scaling, pf_type])        layernorm_plug = plg_creator.create_plugin("layernorm_quantized", pfc)        normalized_shape = [normalized_shape] if isinstance(            normalized_shape, int) else normalized_shape        if weight is None:            weight = constant(                np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))        if bias is None:            bias = constant(                np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))        plug_inputs = [            input.trt_tensor, weight.trt_tensor, bias.trt_tensor,            scale.trt_tensor        ]        layer = default_trtnet().add_plugin_v2(plug_inputs, layernorm_plug)        layer.get_output(0).set_dynamic_range(-127, 127)        if not dynamic_act_scaling:            return _create_tensor(layer.get_output(0), layer)        return _create_tensor(layer.get_output(0),                              layer), _create_tensor(layer.get_output(1), layer)
在这里插入图片描述

SmoothQuantRmsNorm

class SmoothQuantRmsNorm(Module):    def __init__(self,                 normalized_shape,                 eps=1e-06,                 elementwise_affine=True,                 dtype=None,                 quant_mode=QuantMode(0),                 bias=False):        super().__init__()        if isinstance(normalized_shape, int):            normalized_shape = (normalized_shape, )        if not quant_mode.has_act_and_weight_quant():            raise ValueError(                "SmoothQuant Rms norm has to have some quantization mode set")        self.normalized_shape = tuple(normalized_shape)        self.elementwise_affine = elementwise_affine        if self.elementwise_affine:            self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)        else:            self.register_parameter('weight', None)        if bias:            self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)        else:            self.register_parameter('bias', None)        self.eps = eps        self.quant_mode = quant_mode        if self.quant_mode.has_act_and_weight_quant():            self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)        else:            self.register_parameter('scale_to_int', None)    def forward(self, x):        weight = None if self.weight is None else self.weight.value        bias = None if self.bias is None else self.bias.value        scale = None if self.scale_to_int is None else self.scale_to_int.value        return smooth_quant_rms_norm(            x,            self.normalized_shape,            weight,            bias,            scale,            self.eps,            dynamic_act_scaling=self.quant_mode.has_per_token_dynamic_scaling())
def smooth_quant_rms_norm(input: Tensor,                          normalized_shape: Union[int, Tuple[int]],                          weight: Optional[Tensor] = None,                          bias: Optional[Tensor] = None,                          scale: Optional[Tensor] = None,                          eps: float = 1e-05,                          dynamic_act_scaling: bool = False) -> Tensor:    if not default_net().plugin_config.rmsnorm_quantization_plugin:        raise TypeError("Smooth Quant Rms Norm is only supported with plugin")    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'RmsnormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        eps = trt.PluginField("eps", np.array(eps, dtype=np.float32),                              trt.PluginFieldType.FLOAT32)        dyn_act_scaling = trt.PluginField(            "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),            trt.PluginFieldType.INT32)        p_dtype = default_net().plugin_config.rmsnorm_quantization_plugin        pf_type = trt.PluginField(            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),            trt.PluginFieldType.INT32)        pfc = trt.PluginFieldCollection([eps, dyn_act_scaling, pf_type])        rmsnorm_plug = plg_creator.create_plugin("rmsnorm_quantized", pfc)        normalized_shape = [normalized_shape] if isinstance(            normalized_shape, int) else normalized_shape        if weight is None:            weight = constant(                np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))        if bias is None:            bias = constant(                np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))        plug_inputs = [            input.trt_tensor, weight.trt_tensor, bias.trt_tensor,            scale.trt_tensor        ]        layer = default_trtnet().add_plugin_v2(plug_inputs, rmsnorm_plug)        layer.get_output(0).set_dynamic_range(-127, 127)        if not dynamic_act_scaling:            return _create_tensor(layer.get_output(0), layer)        return _create_tensor(layer.get_output(0),                              layer), _create_tensor(layer.get_output(1), layer)

WeightOnlyQuantColumnLinear

class WeightOnlyQuantLinear(Module):    def __init__(self,                 in_features,                 out_features,                 bias=True,                 dtype=None,                 tp_group=None,                 tp_size=1,                 gather_output=True,                 quant_mode=QuantMode.use_weight_only()):        super().__init__()        if quant_mode.is_int8_weight_only():            self.weight_only_quant_mode = 1            quant_type_size_in_bits = 8        elif quant_mode.is_int4_weight_only():            self.weight_only_quant_mode = 2            quant_type_size_in_bits = 4        self.in_features = in_features        self.out_features = out_features // tp_size        # we use a fake tensor with data_type = float        self.weight = Parameter(shape=(self.in_features,                                       int(self.out_features *                                           quant_type_size_in_bits / 32)),                                dtype="float32")        scale_shape = (self.out_features, )        self.per_channel_scale = Parameter(shape=scale_shape, dtype=dtype)        self.tp_size = tp_size        self.tp_group = tp_group        self.gather_output = gather_output        if bias:            self.bias = Parameter(shape=(self.out_features, ), dtype=dtype)        else:            self.register_parameter('bias', None)    def forward(self, x):        x = weight_only_quant_matmul(x, self.weight.value,                                     self.per_channel_scale.value,                                     self.weight_only_quant_mode)        if self.bias is not None:            x = x + self.bias.value        if self.gather_output and self.tp_size > 1 and self.tp_group is not None:            # 1. [dim0, local_dim] -> [dim0 * tp_size, local_dim]            x = allgather(x, self.tp_group)            # 2. [dim0 * tp_size, local_dim] -> [dim0, local_dim * tp_size]            # 2.1 split            split_size = shape(x, dim=0) / self.tp_size            ndim = x.ndim()            starts = [constant(int32_array([0])) for _ in range(ndim)]            sizes = [shape(x, dim=d) for d in range(ndim)]            sizes[0] = split_size            sections = []            for i in range(self.tp_size):                starts[0] = split_size * i                sections.append(slice(x, concat(starts), concat(sizes)))            # 2.2 concat            x = concat(sections, dim=1)        return xWeightOnlyQuantColumnLinear = WeightOnlyQuantLinear
def weight_only_quant_matmul(input: Tensor, weights: Tensor, scales: Tensor,                             weightTypeId: int) -> Tensor:    if not default_net().plugin_config.weight_only_quant_matmul_plugin:        raise TypeError(            "Weight Only Qunat MatMul is only supported with plugin")    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'WeightOnlyQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        weight_type_id = trt.PluginField("weight_type_id",                                         np.array(weightTypeId, dtype=np.int32),                                         trt.PluginFieldType.INT32)        p_dtype = default_net().plugin_config.weight_only_quant_matmul_plugin        pf_type = trt.PluginField(            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),            trt.PluginFieldType.INT32)        pfc = trt.PluginFieldCollection([pf_type, weight_type_id])        matmul_plug = plg_creator.create_plugin("woq_matmul", pfc)        plug_inputs = [input.trt_tensor, weights.trt_tensor, scales.trt_tensor]        layer = default_trtnet().add_plugin_v2(plug_inputs, matmul_plug)        return _create_tensor(layer.get_output(0), layer)

WeightOnlyQuantRowLinear

class WeightOnlyQuantRowLinear(Module):    def __init__(self,                 in_features,                 out_features,                 bias=True,                 dtype=None,                 tp_group=None,                 tp_size=1,                 quant_mode=QuantMode.use_weight_only()):        super().__init__()        if quant_mode.is_int8_weight_only():            self.weight_only_quant_mode = 1        elif quant_mode.is_int4_weight_only():            self.weight_only_quant_mode = 2        self.in_features = in_features // tp_size        self.out_features = out_features        #we use a fake tensor with data_type = float        self.weight = Parameter(shape=(self.in_features,                                       int(self.out_features / 4 /                                           self.weight_only_quant_mode)),                                dtype="float32")        self.per_channel_scale = Parameter(shape=(self.out_features, ),                                           dtype=dtype)        if bias:            self.bias = Parameter(shape=(self.out_features, ), dtype=dtype)        else:            self.register_parameter('bias', None)        self.tp_group = tp_group        self.tp_size = tp_size    def forward(self, x, workspace=None):        x = weight_only_quant_matmul(x, self.weight.value,                                     self.per_channel_scale.value,                                     self.weight_only_quant_mode)        if self.tp_size > 1 and self.tp_group is not None:            x = allreduce(x, self.tp_group, workspace)        if self.bias is not None:            x = x + self.bias.value        return x

WeightOnlyGroupwiseQuantColumnLinear

class WeightOnlyGroupwiseQuantLinear(Module):    def __init__(self,                 in_features,                 out_features,                 group_size=128,                 pre_quant_scale=False,                 zero=False,                 bias=False,                 dtype=None,                 tp_group=None,                 tp_size=1,                 gather_output=True):        super().__init__()        # Flags for indicating whether the corresponding inputs are applied in quant_algo        BIAS = 1        ZERO = 2        PRE_QUANT_SCALE = 4        self.quant_algo = pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS        self.group_size = group_size        self.in_features = in_features        self.out_features = out_features // tp_size        self.qweight = Parameter(shape=(self.in_features,                                        self.out_features // 8),                                 dtype="float32")        scale_shape = (self.in_features // group_size, self.out_features)        self.scale = Parameter(shape=scale_shape, dtype=dtype)        if pre_quant_scale:            self.pre_quant_scale = Parameter(shape=(1, self.in_features),                                             dtype=dtype)        else:            self.register_parameter('pre_quant_scale', None)        if zero:            self.zero = Parameter(shape=scale_shape, dtype=dtype)        else:            self.register_parameter('zero', None)        if bias:            self.bias = Parameter(shape=(self.out_features, ), dtype=dtype)        else:            self.register_parameter('bias', None)        self.tp_size = tp_size        self.tp_group = tp_group        self.gather_output = gather_output    def forward(self, x):        pre_quant_scale = self.pre_quant_scale.value if self.pre_quant_scale else None        zero = self.zero.value if self.zero else None        bias = self.bias.value if self.bias else None        x = weight_only_groupwise_quant_matmul(x, pre_quant_scale,                                               self.qweight.value,                                               self.scale.value, zero, bias,                                               self.quant_algo, self.group_size)        if self.gather_output and self.tp_size > 1 and self.tp_group is not None:            # 1. [dim0, local_dim] -> [dim0 * tp_size, local_dim]            x = allgather(x, self.tp_group)            # 2. [dim0 * tp_size, local_dim] -> [dim0, local_dim * tp_size]            # 2.1 split            split_size = shape(x, dim=0) / self.tp_size            ndim = x.ndim()            starts = [constant(int32_array([0])) for _ in range(ndim)]            sizes = [shape(x, dim=d) for d in range(ndim)]            sizes[0] = split_size            sections = []            for i in range(self.tp_size):                starts[0] = split_size * i                sections.append(slice(x, concat(starts), concat(sizes)))            # 2.2 concat            x = concat(sections, dim=1)        return xWeightOnlyGroupwiseQuantColumnLinear = WeightOnlyGroupwiseQuantLinear
def weight_only_groupwise_quant_matmul(input: Tensor, pre_quant_scale: Tensor,                                       weights: Tensor, scales: Tensor,                                       zeros: Tensor, biases: Tensor,                                       quant_algo: int,                                       group_size: int) -> Tensor:    if not default_net(    ).plugin_config.weight_only_groupwise_quant_matmul_plugin:        raise TypeError(            "Weight Only Groupwise Quant MatMul is only supported with plugin")    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'WeightOnlyGroupwiseQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        quant_algo_ = trt.PluginField("quant_algo",                                      np.array(quant_algo, dtype=np.int32),                                      trt.PluginFieldType.INT32)        group_size_ = trt.PluginField("group_size",                                      np.array(group_size, dtype=np.int32),                                      trt.PluginFieldType.INT32)        p_dtype = default_net(        ).plugin_config.weight_only_groupwise_quant_matmul_plugin        pf_type_ = trt.PluginField(            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),            trt.PluginFieldType.INT32)        pfc = trt.PluginFieldCollection([pf_type_, quant_algo_, group_size_])        matmul_plug = plg_creator.create_plugin("woq_groupwise_matmul", pfc)        # quant_algo = pre_quant_scale * 4 + zero * 2 + bias        plug_inputs = [input.trt_tensor]        # Flags for indicating whether the corresponding inputs are applied in quant_algo        # quant_algo = pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS        # Here pre_quant_scale, zero and bias are boolean type        BIAS = 1        ZERO = 2        PRE_QUANT_SCALE = 4        if quant_algo & PRE_QUANT_SCALE:            plug_inputs += [pre_quant_scale.trt_tensor]        plug_inputs += [weights.trt_tensor, scales.trt_tensor]        if quant_algo & ZERO:            plug_inputs += [zeros.trt_tensor]        if quant_algo & BIAS:            plug_inputs += [biases.trt_tensor]        layer = default_trtnet().add_plugin_v2(plug_inputs, matmul_plug)        return _create_tensor(layer.get_output(0), layer)

WeightOnlyGroupwiseQuantRowLinear

class WeightOnlyGroupwiseQuantRowLinear(Module):    def __init__(self,                 in_features,                 out_features,                 group_size=128,                 pre_quant_scale=False,                 zero=False,                 bias=False,                 dtype=None,                 tp_group=None,                 tp_size=1):        super().__init__()        # Flags for indicating whether the corresponding inputs are applied in quant_algo        BIAS = 1        ZERO = 2        PRE_QUANT_SCALE = 4        self.quant_algo = pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS        self.group_size = group_size        self.in_features = in_features // tp_size        self.out_features = out_features        self.qweight = Parameter(shape=(self.in_features,                                        self.out_features // 8),                                 dtype="float32")        scale_shape = (self.in_features // group_size, self.out_features)        self.scale = Parameter(shape=scale_shape, dtype=dtype)        if pre_quant_scale:            self.pre_quant_scale = Parameter(shape=(1, self.in_features),                                             dtype=dtype)        else:            self.register_parameter('pre_quant_scale', None)        if zero:            self.zero = Parameter(shape=scale_shape, dtype=dtype)        else:            self.register_parameter('zero', None)        if bias:            self.bias = Parameter(shape=(self.out_features, ), dtype=dtype)        else:            self.register_parameter('bias', None)        self.tp_size = tp_size        self.tp_group = tp_group    def forward(self, x, workspace=None):        pre_quant_scale = self.pre_quant_scale.value if self.pre_quant_scale else None        zero = self.zero.value if self.zero else None        bias = self.bias.value if self.bias else None        x = weight_only_groupwise_quant_matmul(x, pre_quant_scale,                                               self.qweight.value,                                               self.scale.value, zero, bias,                                               self.quant_algo, self.group_size)        if self.tp_size > 1 and self.tp_group is not None:            x = allreduce(x, self.tp_group, workspace)        return x

SmoothQuantMLP

class SmoothQuantMLP(Module):    def __init__(self,                 hidden_size,                 ffn_hidden_size,                 hidden_act,                 bias=True,                 dtype=None,                 tp_group=None,                 tp_size=1,                 quant_mode=QuantMode(0)):        super().__init__()        if hidden_act not in ACT2FN:            raise ValueError(                'unsupported activation function: {}'.format(hidden_act))        self.fc = SmoothQuantColumnLinear(hidden_size,                                          ffn_hidden_size,                                          bias=bias,                                          dtype=dtype,                                          tp_group=tp_group,                                          tp_size=tp_size,                                          gather_output=False,                                          quant_mode=quant_mode)        self.proj = SmoothQuantRowLinear(ffn_hidden_size,                                         hidden_size,                                         bias=bias,                                         dtype=dtype,                                         tp_group=tp_group,                                         tp_size=tp_size,                                         quant_mode=quant_mode)        self.hidden_act = hidden_act        self.quant_mode = quant_mode        if self.quant_mode.has_act_static_scaling():            self.quantization_scaling_factor = Parameter(shape=(1, ),                                                         dtype='float32')        else:            self.register_parameter('quantization_scaling_factor', None)    def forward(self, hidden_states, workspace=None):        inter = self.fc(hidden_states)        inter = ACT2FN[self.hidden_act](inter)        inter = inter / self.proj.smoother.value        if self.quant_mode.has_act_and_weight_quant():            if self.quant_mode.has_act_static_scaling():                # Avoid quantiztion layers as it breaks int8 plugins                inter = quantize_tensor(inter,                                        self.quantization_scaling_factor.value)            else:                # Quantize per token outputs tuple:                # quantized tensor and scaling factors per token                inter = quantize_per_token(inter)        output = self.proj(inter, workspace)        return output
def quantize_tensor(x, scale):    if not default_net().plugin_config.quantize_tensor_plugin:        scaled = x * scale        rounded = round(scaled)        clipped = clip(rounded, -128, 127)        quantized = cast(clipped, 'int8')    else:        plg_creator = trt.get_plugin_registry().get_plugin_creator(            'QuantizeTensor', '1', TRT_LLM_PLUGIN_NAMESPACE)        assert plg_creator is not None        pfc = trt.PluginFieldCollection([])        quantize_plug = plg_creator.create_plugin("quantize_tensor_plugin", pfc)        plug_inputs = [x.trt_tensor, scale.trt_tensor]        layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)        layer.get_output(0).set_dynamic_range(-127, 127)        quantized = _create_tensor(layer.get_output(0), layer)        quantized.trt_tensor.dtype = str_dtype_to_trt("int8")    return quantized

参考文献

  • • https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/quantization/mode.py
  • • https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/quantization/layers.py
  • • https://qiita.com/kyad/items/894bea24fdd0ed79318b
点个「赞」+「在看」❤️
让我们知道这份文字有温暖到你,也是我们持续创作的最大动力!
推荐
Lock-Free 队列实现原理
Share Memory 的 Bank Conflict
告别高成本!TensorRT-LLM实战:如何将LLM推理速度提升数倍
使用LoRA对LLM进行微调的实用技巧
强化学习小白必看:PTX Loss 到底是个啥?
GPT-5 Prompt Migration and Improvement Using the New Optimizer
Task 异步流 coroutine 实现
C++ corotine 介绍
搭建 VSCode 离线开发环境
nlohmann/json 库简介
Intro to C++ Coroutines: Concept
Hugging Face BPE Tokenizer 的资源文件
移动语义 std::move 和完美转发 std::forward
ACEBench: Who Wins the Match Point in Tool Usage?
什么是 GN
RULER: Relative Universal LLM-Elicited Rewards
SFT和RFT的区别
CosyVoice 3: 面向真实场景的大规模零样本语音生成模型
CosyVoice 3: Towards In-the-wild Speech Generation
语音合成(TTS)中文自然度:问题、成因、解决方案
上下文工程如何实现
上下文工程(Context Engineering)
新手必看!LangGraph 101:手把手教你搭一个深度研究 Agent
LangGraph 简介
SFT 泛化新解读:强化学习 + 奖励修正,一文读懂
程序员狂喜!Self-Instruct 框架全解析:无限生成高质量指令集,从此告别标注噩梦!
Evol-Instruct 竟能精准生成领域专属数据?实操技巧速看!
指令微调数据-少即是多
LLM generate 参数怎么用?
语音合成(TTS)跳跃与重复问题的解析:成因、机制及解决方案
大模型训练新思路:GEPA 靠 “反思” 赢过 RL,看完秒懂
F5-TTS:用 Flow Matching 玩转语音,流畅度和真实感都 “拉满” 了
E2 TTS:令人尴尬地简单、完全非自回归、零样本的语音合成技术
Voicebox:大规模文本引导的多语言通用语音生成技术
为什么都在聊 Kimi K2?Open Agentic Intelligence 藏着哪些新惊喜
Step-Audio-AQAA 端到端音频模型
DPO、PPO、GRPO的原理,区别与联系
OPENCSG 中文语料库:一系列高质量的中文数据集,用于语言模型训练
什么是 Classifier-Free Guidance?
Conditional Flow Matching : 连续标准流 Continuous Normalizing Flow
CFM 与 OT-CFM:条件流匹配与最优传输的碰撞
DPO损失实现
Conditional Flow Matching : 常微分方程ODE、欧拉方法和Neural ODE
当 Normalizing flow 遇上语音生成:AI 说话变 “真人” 的秘密在这里!
深度剖析:Kimi – Audio 中 BigVGAN 的神奇作用
为什么说分布变换是 Normalizing flow 的「灵魂操作」?
MATCHA-TTS 来了!条件流匹配让文本转语音效率飙升
从知识增长的角度提升RAG上下文的质量
MiniMax-Speech,零样本语音合成新突破,32 种语言轻松拿捏!
手把手教你创建 evol-instruct 数据集!附完整流程~
社交类聊天的 Query 分析与应答策略
SFT 中指令选择和响应选择哪个更重要?
角色扮演大模型技术分享2-超拟人模型的困境
最新!SpeechLLM 综述:架构、能力、挑战与未来全揭秘
如何低成本生成高质量指令微调数据?
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
Kimi-Audio:开源音频基础模型全面解析
Kimi-Audio 的 TTS 效果如何?
Qwen 的训练数据是怎么做的?
GeForce RTX 3090, 4090, A10, A40, A100, A800, L20, L40 显卡性能对比
如何低成本生成高质量指令微调数据?
掌握RAG:投入生产前要评估的8个场景
掌握RAG:如何评估RAG的LLM
掌握RAG:如何在部署后观察您的RAG
掌握RAG:如何选择嵌入模型
基础模型中的新范式:为什么o1是不同的,以及它将如何改变LLM应用
Semantic token和连续特征在SLLM下的对比
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
RLHF及其变体:进展和实际工程见解
Freeze-Omni: 低延迟语音对话模型
Fully Sharded Data Parallelism (FSDP)
什么是置信度?置信度模型怎么做?
晦涩难懂的 Flow matching!图形化理解
中文指令微调数据,质量就是一切!
基于 LLM 的文本泛化
CosyVoice 2:基于大型语言模型的可扩展流式语音合成技术
Mini-Omni2: with Vision, Speech and Duplex Capabilities
FSQ的原理与VQ-VAE的区别和联系
大模型并行训练的一些知识——极简版
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
要用 AI 裁员 50% 的千亿独角兽,公开认错,重启招聘!
single codebook和dual codebook在LLM中向量量化上有什么区别?
一些文档去重算法
最佳的指令数据应当是什么样的?
Prefill-Decode分离
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
Simhash-文档去重算法简介
RLHF 入门,高手勿进!
最佳的指令数据应当是什么样的?
CosyVoice:一种基于监督式语义标记的可扩展多语言 Zero-Shot 语音合成器
Model Context Protocol (MCP)
MCP(模型上下文协议)是什么以及它是如何运作的
压力测试LLMs——大海捞针实现