
TensorRT-LLM 0.5.0 Source Code, Part 12


attention.py

class RotaryScalingType(IntEnum):
    none = 0
    linear = 1
    dynamic = 2
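
As a quick aside (a minimal sketch, not part of the source), this is how the Attention constructor shown later in this post maps a rotary_embedding_scaling dict ({"type": ..., "factor": ...}) onto this enum:

rotary_embedding_scaling = {"type": "linear", "factor": 2.0}
scale_type = (RotaryScalingType.linear
              if rotary_embedding_scaling["type"] == "linear" else
              RotaryScalingType.dynamic)
print(scale_type, rotary_embedding_scaling["factor"])  # RotaryScalingType.linear 2.0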


class PositionEmbeddingType(IntEnum):
    learned_absolute = 0
    rope_gptj = 1
    rope_gpt_neox = 2
    alibi = 3
    alibi_with_scale = 4
    relative = 5

    def is_rope(self) -> bool:
        return self in [self.rope_gptj, self.rope_gpt_neox]

    def is_alibi(self) -> bool:
        return self in [self.alibi, self.alibi_with_scale]

    @staticmethod
    def choices() -> List[str]:
        return [embedding.name for embedding in PositionEmbeddingType]
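
A quick illustration (not from the source) of the helper methods defined above:

pe = PositionEmbeddingType.rope_gpt_neox
print(pe.is_rope())    # True
print(pe.is_alibi())   # False
print(PositionEmbeddingType.choices())
# ['learned_absolute', 'rope_gptj', 'rope_gpt_neox', 'alibi', 'alibi_with_scale', 'relative']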


class AttentionMaskType(IntEnum):
    padding = 0
    causal = 1
    bidirectional = 2


class LayerNormType(IntEnum):
    LayerNorm = 0
    RmsNorm = 1
    GroupNorm = 2


class LayerNormPositionType(IntEnum):
    pre_layernorm = 0
    post_layernorm = 1

AttentionParams

class AttentionParams:

    def __init__(self,
                 sequence_length: Tensor = None,
                 context_lengths: Tensor = None,
                 host_context_lengths: Tensor = None,
                 max_context_length: int = None,
                 host_request_types: Tensor = None,
                 encoder_input_lengths: Tensor = None,
                 encoder_max_input_length: Tensor = None):
        self.sequence_length = sequence_length  # length generated so far for each sequence (token count)
        # Context length of each sequence. The host_ prefix indicates the tensor lives in
        # host memory rather than on the device, for logic that needs the CPU.
        self.context_lengths = context_lengths
        self.host_context_lengths = host_context_lengths
        # max allowed context length. Required to
        # compute scratch memory size.
        self.max_context_length = max_context_length  # largest context length in the batch, used to size scratch memory
        self.host_request_types = host_request_types  # marks each request's type, e.g. context (prefill) vs. generation

        # The following parameters are used only for cross attention. In encoder-decoder
        # architectures (e.g. T5) they describe the lengths of the encoder-side inputs.
        self.encoder_input_lengths = encoder_input_lengths
        self.encoder_max_input_length = encoder_max_input_length

    def is_valid_cross_attn(self, do_cross_attention):
        if do_cross_attention:
            if self.encoder_input_lengths is None:
                return False
            if self.encoder_max_input_length is None:
                return False
        return True

    def is_valid(self, gpt_attention_plugin, remove_input_padding):
        if gpt_attention_plugin:
            if self.sequence_length is None:
                return False
            if self.context_lengths is None:
                return False
            if self.host_request_types is None:
                return False
            if self.max_context_length is None:
                return False

        if remove_input_padding:
            if self.host_context_lengths is None:
                return False
            if not gpt_attention_plugin:
                return False

        return True
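
A minimal sketch (not from the source) of how the validity rules behave; plain Python objects stand in for the tensorrt_llm Tensors that a real network build would pass in:

params = AttentionParams(sequence_length=object(),
                         context_lengths=object(),
                         host_context_lengths=None,
                         max_context_length=1024,
                         host_request_types=object())
print(params.is_valid(gpt_attention_plugin=True, remove_input_padding=False))  # True
print(params.is_valid(gpt_attention_plugin=True, remove_input_padding=True))   # False: host_context_lengths is missing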

KeyValueCacheParams

class KeyValueCacheParams:

    def __init__(self,
                 past_key_value: List[Tensor] = None,
                 host_past_key_value_lengths: Tensor = None,
                 kv_cache_block_pointers: List[Tensor] = None,
                 cache_indirection: Tensor = None,
                 past_key_value_length: Tensor = None):
        # The cache itself. Usually a list with one entry per layer, each holding that
        # layer's Key and Value tensors (possibly packed together).
        self.past_key_value = past_key_value
        # Current cache length of each sequence, stored in host memory so the CPU can
        # track each sequence's progress for control logic.
        self.host_past_key_value_lengths = host_past_key_value_lengths
        # Block pointers for paged memory management, similar to PagedAttention in vLLM:
        # the KV cache lives in non-contiguous physical blocks mapped through a pointer
        # array, which greatly reduces fragmentation and improves throughput.
        self.kv_cache_block_pointers = kv_cache_block_pointers
        # Indirection table used in batched inference, e.g. for beam search or dynamic
        # sequence management, to map logical sequence positions to physical cache slots.
        self.cache_indirection = cache_indirection
        # self.past_key_value_length = past_key_value_length

    def get_first_past_key_value(self):
        if self.past_key_value is None:
            return None
        return self.past_key_value[0]

    def get_first_kv_cache_block_pointers(self):
        if self.kv_cache_block_pointers is None:
            return None
        return self.kv_cache_block_pointers[0]

    def is_valid(self, gpt_attention_plugin):
        if gpt_attention_plugin:
            if self.host_past_key_value_lengths is None:
                return False
            if self.cache_indirection is None:
                return False

        return True
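
To make the contiguous (non-paged) KV-cache layout concrete, here is a minimal PyTorch sketch with made-up sizes; it mirrors the per-layer shape [max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head] documented in the gpt_attention docstring below (in a real engine these are TensorRT network tensors, not torch tensors):

import torch

max_batch_beam, num_heads, max_seqlen, head_dim, num_layers = 4, 8, 1024, 64, 2
past_key_value = [
    torch.zeros(max_batch_beam, 2, num_heads, max_seqlen, head_dim)  # K and V stacked on dim 1
    for _ in range(num_layers)
]
print(past_key_value[0].shape)  # torch.Size([4, 2, 8, 1024, 64])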

Attention

class Attention(Module):

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=1024,
                 num_layers=1,
                 apply_query_key_layer_scaling=False,
                 attention_mask_type=AttentionMaskType.padding,
                 bias=True,
                 dtype=None,
                 position_embedding_type=PositionEmbeddingType.learned_absolute,
                 rotary_embedding_base=10000.0,
                 rotary_embedding_scaling=None,
                 use_int8_kv_cache=False,
                 rotary_embedding_percentage=1.0,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 multi_block_mode=False,
                 quant_mode: QuantMode = QuantMode(0),
                 q_scaling=1.0,
                 cross_attention=False,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0,
                 instance_id: int = 0):
        super().__init__()

        self.cross_attention = cross_attention
        self.attention_mask_type = attention_mask_type
        self.attention_head_size = hidden_size // num_attention_heads

        # Tensor parallelism splits along the head dimension.
        # The attention heads are distributed evenly across devices; the assertion
        # num_attention_heads % tp_size == 0 guarantees an even split and avoids
        # inconsistent computation.
        assert num_attention_heads % tp_size == 0, \
        "num_attention_heads must be divisible by tp_size"
        self.num_attention_heads = num_attention_heads // tp_size
        # The per-device hidden size shrinks accordingly.
        self.hidden_size = hidden_size // tp_size
        # If num_kv_heads is given, split it as (num_kv_heads + tp_size - 1) // tp_size
        # (rounding up); otherwise use the same count as the query heads. This allows
        # fewer KV heads than query heads (as in GQA), reducing compute and memory.
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads

        self.max_position_embeddings = max_position_embeddings
        self.tp_size = tp_size
        self.tp_rank = tp_rank

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        # Scales the dot-product attention scores to keep them numerically stable
        # (formula: softmax(Q*K^T / sqrt(d_k))).
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = q_scaling
        if self.apply_query_key_layer_scaling:
            # The scaling factor is additionally multiplied by the layer count
            # (self.norm_factor *= self.num_layers), which helps stability in deep networks.
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers
        # Whether to scale ALiBi bias. Mathematically, it's equivalent to
        # normalizing QK after adding bias.
        #   - False, inv_sqrt_Dh * Q*K^T + alibi_bias
        #   - True,  inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias
        self.scale_alibi_bias = position_embedding_type == PositionEmbeddingType.alibi_with_scale

        self.position_embedding_type = position_embedding_type
        self.multi_block_mode = multi_block_mode

        self.relative_attention = relative_attention
        self.max_distance = max_distance

        self.rotary_embedding_base = rotary_embedding_base
        self.rotary_embedding_scale_type = RotaryScalingType.none
        self.rotary_embedding_scale = 1.0
        if rotary_embedding_scaling is not None:
            assert rotary_embedding_scaling["type"] in ["linear", "dynamic"]
            self.rotary_embedding_scale_type = RotaryScalingType.linear if rotary_embedding_scaling[
                "type"] == "linear" else RotaryScalingType.dynamic
            self.rotary_embedding_scale = rotary_embedding_scaling["factor"]
            assert self.rotary_embedding_scale > 1.0
        self.rotary_embedding_dim = 0
        if self.position_embedding_type.is_rope():
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)
            # TODO: Once we add RotaryEmbedding outside GPTAttention plugin,
            #       we need to set it up here


        self.dtype = dtype
        self.quant_mode = quant_mode
        if use_int8_kv_cache:
            # TODO: remove use_int8_kv_cache as can be replaced by quant_mode.has_kv_cache_quant()
            # Merge int8 setting into quant_mode
            self.quant_mode = self.quant_mode.set_int8_kv_cache()
        self.use_int8_kv_cache = use_int8_kv_cache
        if self.quant_mode.has_kv_cache_quant():
            self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32')
            self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32')
        else:
            self.register_parameter('kv_orig_quant_scale', None)
            self.register_parameter('kv_quant_orig_scale', None)

        # The output feature size is therefore (h/tp + 2*kvh/tp) * d, where h is num_heads,
        # d is head_size, kvh is the num_kv_heads and tp is tensor_parallel_size.
        # In ColumnLinear op, the output dim is calculated by (h + 2*kvh) * d / tp,
        # which matches the desired output size (h/tp + 2*kvh/tp) * d after splitting
        self.use_fp8_qdq = self.quant_mode.has_fp8_qdq()
        if self.use_fp8_qdq:
            self.qkv = FP8Linear(hidden_size,
                                 hidden_size +
                                 (2 * tp_size * self.num_attention_kv_heads *
                                  self.attention_head_size),
                                 bias=bias,
                                 dtype=dtype,
                                 tp_group=tp_group,
                                 tp_size=tp_size,
                                 gather_output=False)
            self.dense = FP8RowLinear(hidden_size,
                                      hidden_size,
                                      bias=bias,
                                      dtype=dtype,
                                      tp_group=tp_group,
                                      tp_size=tp_size,
                                      instance_id=instance_id)
        else:
            # Projects the input into the Q, K and V spaces in one shot. The output width is
            # hidden_size + (2 * tp_size * num_attention_kv_heads * attention_head_size),
            # matching the Q/K/V split under tensor parallelism. Column parallelism: the
            # weight matrix is split by columns, each device computes a partial output,
            # and the results would be stitched together via All-Gather.
            # (the "2" accounts for K and V)
            self.qkv = ColumnLinear(hidden_size,
                                    hidden_size +
                                    (2 * tp_size * self.num_attention_kv_heads *
                                     self.attention_head_size),
                                    bias=bias,
                                    dtype=dtype,
                                    tp_group=tp_group,
                                    tp_size=tp_size,
                                    gather_output=False)
            self.dense = RowLinear(hidden_size,
                                   hidden_size,
                                   bias=bias,
                                   dtype=dtype,
                                   tp_group=tp_group,
                                   tp_size=tp_size,
                                   instance_id=instance_id)

        # per-layer relative attention table
        if relative_attention:
            self.rel_attn_table = Parameter(shape=(num_attention_heads //
                                                   tp_size, num_buckets),
                                            dtype=dtype)
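
A quick numeric illustration (hypothetical sizes, not from the source) of the head split and QKV projection width computed in the constructor above:

hidden_size, num_attention_heads, num_kv_heads, tp_size = 4096, 32, 8, 4
head_size = hidden_size // num_attention_heads                  # 128

heads_per_rank = num_attention_heads // tp_size                 # 8 query heads per GPU
kv_heads_per_rank = (num_kv_heads + tp_size - 1) // tp_size     # 2 KV heads per GPU (GQA)

# ColumnLinear output width before the tp split, as passed to self.qkv above:
qkv_out = hidden_size + 2 * tp_size * kv_heads_per_rank * head_size   # 4096 + 2048 = 6144
per_rank_out = qkv_out // tp_size                               # 1536 = (8 + 2*2) * 128
print(heads_per_rank, kv_heads_per_rank, qkv_out, per_rank_out)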

In Transformer models, the shape of past_key_value is closely tied to its data structure: once you understand the structure, the shape is easy to remember. Its core purpose is to cache the Key and Value tensors computed at all previous time steps during autoregressive generation (e.g. text generation), so they are not recomputed, which greatly improves inference efficiency.

For a quick overview, the table below summarizes its key shape characteristics.

| Scope | Data structure | Shape of a single Key or Value tensor | Notes |
|---|---|---|---|
| Single attention layer (self-attention) | A tuple (past_key, past_value) | (batch_size, num_heads, sequence_length, embed_size_per_head) | The most basic building block |
| Whole model (all layers) | A tuple of tuples, of length n_layers | Each per-layer tuple holds 2 (self-attention) or 4 (encoder-decoder) tensors of the shape above | e.g. past_key_values = ( (layer1_k, layer1_v), (layer2_k, layer2_v), ... ) |
| Whole model (all generation steps) | List[tuple(tuple())], of length n_steps | Each per-layer tuple holds 2 (self-attention) or 4 (encoder-decoder) tensors of the shape above | e.g. past_key_values = [Step_1( (layer1_k, layer1_v), (layer2_k, layer2_v), ... ), ..., Step_n] |
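
A minimal sketch of the per-layer layout summarized in the table, following the Hugging Face Transformers convention (sizes are hypothetical):

import torch

batch_size, num_heads, seq_len, head_dim, n_layers = 2, 16, 10, 64, 4
past_key_values = tuple(
    (torch.zeros(batch_size, num_heads, seq_len, head_dim),   # past_key of layer i
     torch.zeros(batch_size, num_heads, seq_len, head_dim))   # past_value of layer i
    for _ in range(n_layers)
)
print(len(past_key_values), past_key_values[0][0].shape)  # 4 torch.Size([2, 16, 10, 64])
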
def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        encoder_output: Optional[Tensor] = None,
        workspace=None,
):

        assert isinstance(hidden_states, Tensor)

        alibi_slopes = None
        if self.position_embedding_type.is_rope():
            if not default_net().plugin_config.gpt_attention_plugin:
                raise ValueError(
                    'RoPE is only supported with GPTAttention plugin')
        elif self.position_embedding_type.is_alibi():
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin:
                dtype = hidden_states.dtype
            alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1.
            alibi_slopes = generate_alibi_slopes(self.num_attention_heads *
                                                 self.tp_size,
                                                 dtype=dtype,
                                                 tp_size=self.tp_size,
                                                 tp_rank=self.tp_rank,
                                                 alibi_scale=alibi_scale)

        qkv = self.qkv(hidden_states)

        paged_kv_cache = default_net().plugin_config.paged_kv_cache

        assert attention_params is None or attention_params.is_valid(
            default_net().plugin_config.gpt_attention_plugin,
            default_net().plugin_config.remove_input_padding)
        assert kv_cache_params is None or kv_cache_params.is_valid(
            default_net().plugin_config.gpt_attention_plugin)

        # step1
        past_key_value = None if kv_cache_params is None else kv_cache_params.get_first_past_key_value(
        )
        if self.cross_attention and (past_key_value is not None):
            past_key_value = kv_cache_params.past_key_value[1]

        # if cross attention, cross QKV only needs to be calculated once in the
        # 1st decoding step --> write to cross KV cache --> remains constant
        # during the entire decoding. 1st and >1 steps are distinguished by
        # whether past_key_value exists or not
        # also, cross KV cache max length is set from encoder output seqlen,
        # this maps to the max context length concept in decoder-only models
        cross_qkv = None
        # get length data in every run
        if encoder_output:
            assert isinstance(encoder_output, Tensor)
        # but only do projection once at 1st decoding step
        if self.cross_attention and encoder_output:
            cross_qkv = self.qkv(encoder_output)

        if default_net().plugin_config.gpt_attention_plugin:
            assert self.attention_mask_type in [
                AttentionMaskType.causal, AttentionMaskType.bidirectional
            ], 'Plugin only support masked MHA.'
            kv_orig_quant_scale = self.kv_orig_quant_scale.value if self.quant_mode.has_kv_cache_quant(
            ) else None
            kv_quant_orig_scale = self.kv_quant_orig_scale.value if self.quant_mode.has_kv_cache_quant(
            ) else None
            context, past_key_value = gpt_attention(
                tensor=qkv,
                past_key_value=past_key_value,
                sequence_length=attention_params.sequence_length,
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                context_lengths=attention_params.context_lengths,
                cache_indirection=kv_cache_params.cache_indirection,
                host_request_types=attention_params.host_request_types,
                num_heads=self.num_attention_heads,
                num_kv_heads=self.num_attention_kv_heads,
                hidden_size_per_head=self.attention_head_size,
                q_scaling=self.q_scaling,
                rotary_embedding_dim=self.rotary_embedding_dim,
                rotary_embedding_base=self.rotary_embedding_base,
                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                rotary_embedding_scale=self.rotary_embedding_scale,
                rotary_embedding_max_positions=self.max_position_embeddings,
                position_embedding_type=self.position_embedding_type,
                multi_block_mode=self.multi_block_mode,
                kv_orig_quant_scale=kv_orig_quant_scale,
                kv_quant_orig_scale=kv_quant_orig_scale,
                kv_cache_quant_mode=self.quant_mode,
                max_context_length=attention_params.max_context_length,
                mask_type=self.attention_mask_type,
                alibi_slopes=alibi_slopes,
                tp_size=self.tp_size,
                tp_rank=self.tp_rank,
                kv_cache_block_pointers=kv_cache_params.
                get_first_kv_cache_block_pointers(),
                do_cross_attention=self.cross_attention,
                cross_qkv=cross_qkv,
                cross_qkv_length=attention_params.encoder_max_input_length,
                encoder_input_lengths=attention_params.encoder_input_lengths,
                relative_attention_bias=self.rel_attn_table.value
                if self.relative_attention else None,
                max_distance=self.max_distance,
                host_context_lengths=attention_params.host_context_lengths,
            )

        else:
            # plain TensorRT mode
            assert paged_kv_cache == False

            def transpose_for_scores(x, is_kv: bool = False):
                _num_attention_heads = self.num_attention_kv_heads if is_kv else self.num_attention_heads
                # x [B, T, D] -> [B, T, H, D//H] -> [B, H, T, D//H]
                new_x_shape = concat([
                    shape(x, 0),
                    shape(x, 1), _num_attention_heads, self.attention_head_size
                ])
                return x.view(new_x_shape).permute([0, 2, 1, 3])

            # qkv after projection is of shape
            #   [bs, seqlen, (num_attention_heads + 2 * num_attention_kv_heads), attention_head_size].
            #   [bs, seqlen, tp_size, (num_attention_heads + 2 * num_attention_kv_heads), attention_head_size].
            # The projected and split qkv after transpose_for_scores():
            #   Q[bs, num_attention_heads, seqlen, attention_head_size]
            #   K[bs, num_attention_kv_heads, seqlen, attention_head_size]
            #   V[bs, num_attention_kv_heads, seqlen, attention_head_size]
            # self.hidden_size = num_attention_heads * attention_head_size
            kv_size = self.attention_head_size * self.num_attention_kv_heads
            query, key, value = split(qkv, [self.hidden_size, kv_size, kv_size],
                                      dim=2)

            # in cross attention mode, replace kv by encoder_output
            if self.cross_attention and encoder_output is not None:
                encoder_qkv = self.qkv(encoder_output)  # [B, T, D]
                _, key, value = split(encoder_qkv,
                                      [self.hidden_size, kv_size, kv_size],
                                      dim=2)

            query = transpose_for_scores(query)
            key = transpose_for_scores(key, is_kv=True)
            value = transpose_for_scores(value, is_kv=True)

            if past_key_value is not None:

                def dequantize_tensor(x, scale):
                    # Cast from int8 to dtype
                    casted_x = cast(x, self.dtype)
                    return casted_x * scale

                if self.use_int8_kv_cache:
                    past_key_value = dequantize_tensor(
                        past_key_value, self.kv_quant_orig_scale.value)

                # past_key_value [bs, 2, num_heads, max_seq_len, head_dim]
                past_key, past_value = split(past_key_value, 1, dim=1)

                key_shape = concat([
                    shape(past_key, 0),
                    shape(past_key, 2),
                    shape(past_key, 3),
                    shape(past_key, 4)
                ])
                past_key = past_key.view(key_shape, zero_is_placeholder=False)
                past_value = past_value.view(key_shape,
                                             zero_is_placeholder=False)

                key = concat([past_key, key], dim=2).cast(self.dtype)
                value = concat([past_value, value], dim=2).cast(self.dtype)

            if use_cache:
                key_inflated_shape = concat([
                    shape(key, 0), 1,
                    shape(key, 1),
                    shape(key, 2),
                    shape(key, 3)
                ])
                inflated_key = key.view(key_inflated_shape,
                                        zero_is_placeholder=False)
                inflated_value = value.view(key_inflated_shape,
                                            zero_is_placeholder=False)
                past_key_value = concat([inflated_key, inflated_value], dim=1)

                if self.use_int8_kv_cache:

                    def quantize_tensor(x, scale):
                        scaled = x * scale
                        rounded = round(scaled)
                        clipped = clip(rounded, -128, 127)
                        quantized = cast(clipped, 'int8')
                        return quantized

                    past_key_value = quantize_tensor(
                        past_key_value, self.kv_orig_quant_scale.value)

            key_length = shape(key, 2)

            # The following code creates a 2D tensor with 0s in the lower triangular (including the diagonal) and
            # -INF in the upper triangular parts. This bias tensor will be added to the output of the Q*K^T matrix
            # multiplication (BMM1). The -INF elements will be transformed to 0s by the Softmax operator that
            # follows. The elements that correspond to 0s in the bias are unaffected by the bias tensor.
            #
            # Note that when we add another bias tensor B (for example, with ALiBi), the values in the lower-
            # triangular part of the B tensor are not affected and the upper-triangular ones are set to -INF.
            if self.attention_mask_type == AttentionMaskType.causal:
                query_length = shape(query, 2)
                starts = concat([0, 0, key_length - query_length, 0])
                sizes = concat([1, 1, query_length, key_length])
                select_buf = np.expand_dims(
                    np.tril(
                        np.ones((self.max_position_embeddings,
                                 self.max_position_embeddings))).astype(bool),
                    (0, 1))

                select_buf = np.logical_not(select_buf)
                mask_buf = np.zeros_like(select_buf, np.float32)
                mask_buf[select_buf] = float('-inf')
                buffer = constant(mask_buf)
                causal_mask = slice(buffer, starts, sizes)

            if attention_mask is not None:
                attention_mask = expand_mask(attention_mask, shape(query, 2))
            bias = attention_mask
            if self.position_embedding_type.is_alibi():
                alibi_biases = generate_alibi_biases(alibi_slopes, key_length)
                bias = alibi_biases if bias is None else bias + alibi_biases

            key = key.permute([0, 1, 3, 2])
            with precision('float32'):
                attention_scores = matmul(cast(query, 'float32'),
                                          cast(key, 'float32'))

                attention_scores = attention_scores / self.norm_factor

                if self.attention_mask_type == AttentionMaskType.causal:
                    bias = causal_mask if bias is None else bias + causal_mask

                if bias is not None and not self.cross_attention:
                    attention_scores = attention_scores + bias

            attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value).permute([0, 2, 1, 3])
            context = context.view(
                concat([shape(context, 0),
                        shape(context, 1), self.hidden_size]))

        context = self.dense(context, workspace)

        if use_cache:
            return (context, past_key_value)
        else:
            return context
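
For intuition, here is a framework-free sketch (plain PyTorch, made-up sizes) of what the plain-TensorRT branch above computes: scaled Q*K^T, an additive causal bias of -inf above the diagonal, softmax, and the weighted sum over V.

import math
import torch

bs, n_heads, q_len, k_len, d = 1, 2, 4, 4, 8
q = torch.randn(bs, n_heads, q_len, d)
k = torch.randn(bs, n_heads, k_len, d)
v = torch.randn(bs, n_heads, k_len, d)

scores = q @ k.transpose(-1, -2) / math.sqrt(d)            # Q*K^T / norm_factor
causal = torch.triu(torch.full((q_len, k_len), float('-inf')), diagonal=1)
probs = torch.softmax(scores + causal, dim=-1)             # -inf entries become 0 after softmax
context = probs @ v                                        # [bs, n_heads, q_len, d]
print(context.shape)
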
@gw.record_signature
def gpt_attention(
        tensor: Tensor,
        past_key_value: Tensor,
        sequence_length: Tensor,
        host_past_key_value_lengths: Tensor,
        context_lengths: Tensor,
        cache_indirection: Tensor,
        host_request_types: Tensor,
        num_heads: int,
        num_kv_heads: int,
        hidden_size_per_head: int,
        q_scaling: float,
        rotary_embedding_dim: int,
        rotary_embedding_base: float = 10000.0,
        rotary_embedding_scale_type: RotaryScalingType = RotaryScalingType.none,
        rotary_embedding_scale: float = 1.0,
        rotary_embedding_max_positions: int = 1024,
        position_embedding_type: PositionEmbeddingType = PositionEmbeddingType.
    learned_absolute,
        multi_block_mode: bool = False,
        kv_orig_quant_scale: Tensor = None,
        kv_quant_orig_scale: Tensor = None,
        kv_cache_quant_mode: QuantMode = None,
        max_context_length: int = None,
        mask_type: AttentionMaskType = AttentionMaskType.causal,
        alibi_slopes: Tensor = None,
        tp_size: int = 1,
        tp_rank: int = 0,
        kv_cache_block_pointers: Tensor = None,
        do_cross_attention: bool = False,
        cross_qkv: Tensor = None,  # for cross attention
        cross_qkv_length: Tensor = None,  # for cross attention
        encoder_input_lengths: Tensor = None,  # for cross attention
        relative_attention_bias: Tensor = None,  # for relative attention
        max_distance: int = 0,  # for relative attention
        host_context_lengths: Tensor = None,  # for pad-free input mode
        qkv_bias: Tensor = None
) -> Tuple[Tensor]:
    '''
    Add an operation that performs the multi-head attention in GPT-like models.

    The signature of the function will change in the future release - we are in
    the process of simplifying the API. The current version is still
    work-in-progress! The following API is provided with hints regarding the
    arguments that are likely to be removed or merged with others in the future
    release.

    See docs/gpt_attention.md for the documentation of that function.

    Parameters:
        tensor: Tensor
            The input QKV tensor. Its shape is [batch_beam_size, max_seqlen, 3
            * hidden_dim] in padded mode and [1, num_tokens, 3 * hidden_dim] in
            packed mode. See QKV Input in docs/gpt_attention.md.

        past_key_value: Tensor
            The tensor that stores KV cache data. Its shape is
            [max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head]
            in contiguous mode and
            [max_blocks, 2, num_heads, num_tokens_per_block, hidden_dim_per_head]
            in paged mode. See KV Cache in docs/gpt_attention.md,

        sequence_lengths: Tensor
            The tensor that stores the length of each sequence. Its shape is
            [batch_size]. See QKV Input in docs/gpt_attention.md,

        host_past_key_value_lengths: Tensor
            An INT32 tensor of shape [batch_size].

        context_lengths: Tensor
            The tensor that stores the context-phase sequence length of each request. Its shape
            is [batch_size]. See QKV Input in doc/functional.py,

        cache_indirection: Tensor
            The tensor to reconstruct the paths when using beam-search. Its
            shape is [batch_size, beam_width, max_seqlen]. See Beam-Search in
            docs/gpt_attention.md,

        host_request_types: Tensor = None
            The tensor on the host that indicates if a request is in context or
            generation phase. Its shape is [batch_size]. See Inflight Batching
            in docs/gpt_attention.md,

        num_heads: int
            The number of heads,

        num_kv_heads: int
            The number of KV heads, generic to handle MHA/MQA/GQA,

        hidden_size_per_head: int
            The hidden size per head,

        q_scaling: float
            The value used to compute the scaling factor applied to the output
            of the Q*K^T product. See Scaling Factors in docs/gpt_attention.md,

        rotary_embedding_dim: int
            The dimension to compute RoPE. Use 0 when position_embedding_type is not RoPE.

        rotary_embedding_base: float
            The theta value to use for RoPE. Ignored when position_embedding_type is not RoPE.

        rotary_embedding_scale_type: RotaryScalingType
            The scaling type of RoPE. Ignored when position_embedding_type is not RoPE.
            Possible rotary scaling type:
                * RotaryScalingType.none
                * RotaryScalingType.linear
                * RotaryScalingType.dynamic

        rotary_embedding_scale: float
            The scale value to use for linear/dynamic scaling in RoPE.
            Ignored when position_embedding_type is not RoPE.
            Must be set to 1 (default) if rotary_embedding_scale_type is `none`.

        rotary_embedding_max_positions: int
            Needed only for `dynamic` RoPE scaling. Ignored otherwise.

        position_embedding_type: PositionEmbeddingType
            The position embedding type:
                * PositionEmbeddingType.learned_absolute
                * PositionEmbeddingType.relative
                * PositionEmbeddingType.rope_gptj
                * PositionEmbeddingType.rope_gpt_neox
                * PositionEmbeddingType.alibi
                * PositionEmbeddingType.alibi_with_scale

        multi_block_mode: bool
            Do we enable multi-block for the masked MHA. See Generation Phase
            in docs/gpt_attention.md,

        kv_orig_quant_scale: Tensor
            The tensor to store the scaling factor for quantization to INT8/FP8
            in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in
            docs/gpt_attention.md,

        kv_quant_orig_scale: Tensor
            The tensor to store the scaling factor for dequantization from
            INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache
            in docs/gpt_attention.md,

        kv_cache_quant_mode: QuantMode (int flags)
            Do we enable the INT8 or FP8 KV cache?

        max_context_length: int32_t
            The length of the longest input sequence. See QKV Input in
            docs/gpt_attention.md,

        mask_type: int = 1
            The type of mask:
                * tensorrt_llm.layers.AttentionMaskType.padding for BERT,
                * tensorrt_llm.layers.AttentionMaskType.causal for GPT,
                * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM,

        alibi_slopes: Tensor
            The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel
            when possible,

        tp_size: int
            The number of processes/GPUs when tensor parallelism is activated,

        tp_rank: int
            The rank of that process (when running tensor parallelism),

        kv_cache_block_pointers:
            The tensor of block pointers for the KV cache. Its shape is
            [max_batch_size, max_beam_width, 2, max_blocks_per_sequence * 2]
            See KV cache section in docs/gpt_attention.md,

        do_cross_attention: bool = False
            Do we use this as cross attention instead of self attention,

        cross_qkv: Tensor = None
            The QKV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 3
            * hidden_dim] in padded mode and [1, num_tokens, 3 * hidden_dim] in
            packed mode,

        cross_qkv_length: Tensor = None
            The length of the longest encoder output sequence,

        encoder_input_lengths: Tensor
            The tensor that stores the length of each encoder input sequence. Its shape is [batch_size],

        relative_attention_bias: Tensor = None
            The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].

        max_distance: int = 0
            The maximum distance of relative position in attention, for implicit mode.
            Default value is 0, meaning to use the regular mode of relative attention bias.
            Implicit mode is only enabled when passing in non-zero positive max_distance value.
            See relative attention bias in docs/gpt_attention.md

        host_context_lengths: Tensor = None
            A host tensor that contains the lengths of the different inputs,

        qkv_bias: Tensor = None,

    Returns:
        The tensor produced by that layer.
    '''

    assert host_request_types is not None
    assert (alibi_slopes is not None) == (position_embedding_type.is_alibi())
    attn_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'GPTAttention', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert attn_plg_creator is not None
    assert host_context_lengths is not None or not default_net(
    ).plugin_config.remove_input_padding
    assert isinstance(max_context_length, int)

    paged_kv_cache_flag = default_net().plugin_config.paged_kv_cache

    nheads = trt.PluginField("num_heads", np.array(num_heads, dtype=np.int32),
                             trt.PluginFieldType.INT32)
    num_kv_heads = trt.PluginField("num_kv_heads",
                                   np.array(num_kv_heads, dtype=np.int32),
                                   trt.PluginFieldType.INT32)
    head_size = trt.PluginField("head_size",
                                np.array(hidden_size_per_head, dtype=np.int32),
                                trt.PluginFieldType.INT32)
    unidirectional = trt.PluginField("unidirectional",
                                     np.array(1, dtype=np.int32),
                                     trt.PluginFieldType.INT32)
    q_scaling = trt.PluginField("q_scaling",
                                np.array(q_scaling, dtype=np.float32),
                                trt.PluginFieldType.FLOAT32)
    rotary_embedding_dim = trt.PluginField(
        "rotary_embedding_dim", np.array(rotary_embedding_dim, dtype=np.int32),
        trt.PluginFieldType.INT32)
    rotary_embedding_base = trt.PluginField(
        "rotary_embedding_base",
        np.array(rotary_embedding_base, dtype=np.float32),
        trt.PluginFieldType.FLOAT32)
    rotary_embedding_scale_type = trt.PluginField(
        "rotary_embedding_scale_type",
        np.array(rotary_embedding_scale_type, dtype=np.int8),
        trt.PluginFieldType.INT8)
    rotary_embedding_scale = trt.PluginField(
        "rotary_embedding_scale",
        np.array(rotary_embedding_scale, dtype=np.float32),
        trt.PluginFieldType.FLOAT32)
    rotary_embedding_max_positions = trt.PluginField(
        "rotary_embedding_max_positions",
        np.array(rotary_embedding_max_positions, dtype=np.int32),
        trt.PluginFieldType.INT32)
    position_embedding_type = trt.PluginField(
        "position_embedding_type",
        np.array(int(position_embedding_type), dtype=np.int8),
        trt.PluginFieldType.INT8)
    context_fmha_type = trt.PluginField(
        "context_fmha_type",
        np.array(np.int8(default_net().plugin_config.context_fmha_type),
                 dtype=np.int8), trt.PluginFieldType.INT8)
    remove_input_padding = trt.PluginField(
        "remove_input_padding",
        np.array(np.int8(default_net().plugin_config.remove_input_padding),
                 dtype=np.int8), trt.PluginFieldType.INT8)
    p_dtype = default_net().plugin_config.gpt_attention_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)
    mask_type = trt.PluginField("mask_type", np.array([int(mask_type)],
                                                      np.int32),
                                trt.PluginFieldType.INT32)
    multi_block_mode = trt.PluginField(
        "multi_block_mode", np.array(np.int8(multi_block_mode), dtype=np.int8),
        trt.PluginFieldType.INT8)
    tp_size = trt.PluginField("tp_size", np.array(tp_size, dtype=np.int32),
                              trt.PluginFieldType.INT32)
    tp_rank = trt.PluginField("tp_rank", np.array(tp_rank, dtype=np.int32),
                              trt.PluginFieldType.INT32)
    kv_cache_quant_mode_field = trt.PluginField(
        "kv_cache_quant_mode",
        np.array(np.int8(kv_cache_quant_mode), dtype=np.int32),
        trt.PluginFieldType.INT32)
    paged_kv_cache = trt.PluginField(
        "paged_kv_cache", np.array(paged_kv_cache_flag, dtype=np.int32),
        trt.PluginFieldType.INT32)
    tokens_per_block = trt.PluginField(
        "tokens_per_block",
        np.array(default_net().plugin_config.tokens_per_block, dtype=np.int32),
        trt.PluginFieldType.INT32)
    max_context_length = trt.PluginField("max_context_length",
                                         np.array(max_context_length, np.int32),
                                         trt.PluginFieldType.INT32)
    if qkv_bias is None:
        qkv_bias_enabled = trt.PluginField("qkv_bias_enabled",
                                           np.array(0, dtype=np.int8),
                                           trt.PluginFieldType.INT8)
    else:
        qkv_bias_enabled = trt.PluginField("qkv_bias_enabled",
                                           np.array(1, dtype=np.int8),
                                           trt.PluginFieldType.INT8)
    do_cross_attention_field = trt.PluginField(
        "do_cross_attention",
        np.array(np.int8(do_cross_attention), dtype=np.int8),
        trt.PluginFieldType.INT8)
    max_distance = trt.PluginField("max_distance",
                                   np.array(max_distance, dtype=np.int32),
                                   trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([
        nheads, num_kv_heads, head_size, unidirectional, q_scaling,
        position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
        rotary_embedding_scale_type, rotary_embedding_scale,
        rotary_embedding_max_positions, tp_size, tp_rank, context_fmha_type,
        multi_block_mode, kv_cache_quant_mode_field, remove_input_padding,
        mask_type, paged_kv_cache, tokens_per_block, pf_type,
        max_context_length, qkv_bias_enabled, do_cross_attention_field,
        max_distance
    ])

    attn_plug = attn_plg_creator.create_plugin("causal_attn", pfc)
    plug_inputs = [
        tensor,
        sequence_length,
        host_past_key_value_lengths,
        context_lengths,
        cache_indirection,
        host_request_types,
    ]

    if paged_kv_cache_flag:
        plug_inputs += [kv_cache_block_pointers]
    else:
        plug_inputs += [past_key_value]

    if kv_cache_quant_mode.has_kv_cache_quant():
        plug_inputs += [kv_orig_quant_scale, kv_quant_orig_scale]

    if alibi_slopes is not None:
        plug_inputs += [alibi_slopes]

    if relative_attention_bias is not None:
        plug_inputs += [relative_attention_bias]

    if do_cross_attention:
        plug_inputs += [cross_qkv, cross_qkv_length, encoder_input_lengths]

    if default_net().plugin_config.remove_input_padding:
        plug_inputs += [host_context_lengths]

    if qkv_bias is not None:
        plug_inputs += [qkv_bias]

    plug_inputs = [i.trt_tensor for i in plug_inputs]
    layer = default_trtnet().add_plugin_v2(plug_inputs, attn_plug)
    output = _create_tensor(layer.get_output(0), layer)
    present_key_value = None
    if not paged_kv_cache_flag:
        present_key_value = _create_tensor(layer.get_output(1), layer)
        assert present_key_value is not None
        expected_outputs = 2
    else:
        expected_outputs = 1

    assert layer.num_outputs == expected_outputs, \
        f"Plugin outputs number mismatch with expected, got {layer.num_outputs}, expected {expected_outputs}"

    if kv_cache_quant_mode.has_int8_kv_cache() and not paged_kv_cache_flag:
        # past key value
        layer.get_input(6).set_dynamic_range(-127, 127)
        # present key value
        layer.get_output(1).set_dynamic_range(-127, 127)

    assert output is not None
    return output, present_key_value

BertAttention

class BertAttention(Module):

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=1024,
                 num_layers=1,
                 q_scaling=1.0,
                 apply_query_key_layer_scaling=False,
                 bias=True,
                 dtype=None,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0):
        super().__init__()

        self.attention_head_size = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = max_position_embeddings
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.tp_size = tp_size
        self.tp_rank = tp_rank

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = q_scaling
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers

        self.dtype = dtype

        self.relative_attention = relative_attention
        self.max_distance = max_distance

        self.qkv = ColumnLinear(hidden_size,
                                hidden_size +
                                (2 * tp_size * self.num_attention_kv_heads *
                                 self.attention_head_size),
                                bias=bias,
                                dtype=dtype,
                                tp_group=tp_group,
                                tp_size=tp_size,
                                gather_output=False)
        self.dense = RowLinear(hidden_size,
                               hidden_size,
                               bias=bias,
                               dtype=dtype,
                               tp_group=tp_group,
                               tp_size=tp_size)

        # per-layer relative attention table
        if relative_attention:
            self.rel_attn_table = Parameter(shape=(num_attention_heads //
                                                   tp_size, num_buckets),
                                            dtype=dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                input_lengths=None):
        assert isinstance(hidden_states, Tensor)

        qkv = self.qkv(hidden_states)

        if default_net().plugin_config.bert_attention_plugin:
            # TRT plugin mode
            assert input_lengths is not None
            context = bert_attention(
                qkv,
                input_lengths,
                self.num_attention_heads,
                self.attention_head_size,
                q_scaling=self.q_scaling,
                relative_attention=self.relative_attention,
                max_distance=self.max_distance,
                relative_attention_bias=self.rel_attn_table.value
                if self.relative_attention else None)
        else:
            # plain TRT mode

            def transpose_for_scores(x):
                new_x_shape = concat([
                    shape(x, 0),
                    shape(x, 1), self.num_attention_heads,
                    self.attention_head_size
                ])
                return x.view(new_x_shape).permute([0, 2, 1, 3])

            query, key, value = split(qkv, self.hidden_size, dim=2)
            query = transpose_for_scores(query)
            key = transpose_for_scores(key)
            value = transpose_for_scores(value)

            key = key.permute([0, 1, 3, 2])
            attention_scores = matmul(query, key)
            attention_scores = attention_scores / self.norm_factor

            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value).permute([0, 2, 1, 3])
            context = context.view(
                concat([shape(context, 0),
                        shape(context, 1), self.hidden_size]))

        context = self.dense(context)

        return context
# tensor is packed, i.e. no padding.
def bert_attention(tensor: Tensor,
                   input_lengths: Tensor,
                   num_heads: int,
                   head_size: int,
                   q_scaling: float,
                   relative_attention: bool = False,
                   relative_attention_bias: Tensor = None,
                   max_distance: int = 0) -> Tuple[Tensor]:
    '''
    Add an operation that performs the multi-head attention in BERT.

    The multihead-attention (MHA) is the sequence of a batched matmul, a
    softmax and a batched matmul as described in
    https://arxiv.org/abs/1706.03762. That function adds an operation that
    performs those computations using a single GPU kernel.

    The input tensor contains the Q, K and V elements. It is a 2D tensor and
    its shape is '[sum_of_tokens, 3*hidden_dim]' where the 'sum_of_tokens' is
    the sum of the sequence lengths in the batch.

    In MHA, the output of the Q*K^T product is scaled by a constant value that
    is computed as:

        1.f / (q_scaling * sqrt(head_size)).

    That 'q_scaling' constant is the last argument of that function.

    That layer is implemented using a plugin (see bertAttentionPlugin).

    Parameters:
        tensor : Tensor
            The QKV input tensor.

        input_lengths : Tensor
            The length of each sequence. It is a 1D tensor of size 'batch_size'.

        num_heads : int
            The number of heads.

        head_size : int
            The size of each head.

        q_scaling : float
            The factor to compute the scaling factor to scale the output of the
            'Q*K^T' product.

        relative_attention: bool = False
            Whether to enable relative attention.

        relative_attention_bias: Tensor = None
            The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].

        max_distance: int = 0
            The maximum distance of relative position in attention, for implicit mode.
            Default value is 0, meaning to use the regular mode of relative attention bias.
            Implicit mode is only enabled when passing in non-zero positive max_distance value.
            See relative attention bias in docs/gpt_attention.md

    Returns:
        The tensor produced by that layer.
    '''

    attn_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'BertAttention', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert attn_plg_creator is not None

    nheads = trt.PluginField("num_heads", np.array(num_heads, dtype=np.int32),
                             trt.PluginFieldType.INT32)
    head_size = trt.PluginField("head_size", np.array(head_size,
                                                      dtype=np.int32),
                                trt.PluginFieldType.INT32)
    q_scaling = trt.PluginField("q_scaling",
                                np.array(q_scaling, dtype=np.float32),
                                trt.PluginFieldType.FLOAT32)
    enable_qk_half_accum = trt.PluginField(
        "enable_qk_half_accum",
        np.array(np.int8(
            default_net().plugin_config.attention_qk_half_accumulation),
                 dtype=np.int8), trt.PluginFieldType.INT8)
    context_fmha_type = trt.PluginField(
        "context_fmha_type",
        np.array(np.int8(default_net().plugin_config.context_fmha_type),
                 dtype=np.int8), trt.PluginFieldType.INT8)
    p_dtype = default_net().plugin_config.bert_attention_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)
    do_relative_attention = trt.PluginField(
        "do_relative_attention",
        np.array(np.int8(relative_attention), dtype=np.int8),
        trt.PluginFieldType.INT8)
    max_distance = trt.PluginField("max_distance",
                                   np.array(max_distance, dtype=np.int32),
                                   trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection([
        nheads, head_size, q_scaling, enable_qk_half_accum, context_fmha_type,
        pf_type, do_relative_attention, max_distance
    ])

    attn_plug = attn_plg_creator.create_plugin("padding_attn", pfc)
    plug_inputs = [tensor, input_lengths]
    if relative_attention_bias is not None:
        plug_inputs += [relative_attention_bias]

    plug_inputs = [i.trt_tensor for i in plug_inputs]

    layer = default_trtnet().add_plugin_v2(plug_inputs, attn_plug)
    assert layer.num_outputs == 1, \
        f"Plugin outputs number mismatch with expected, got {layer.num_outputs}, expected 1"

    output = _create_tensor(layer.get_output(0), layer)
    assert output is not None
    return output

Others

def generate_alibi_slopes(num_heads: int,
                          dtype: trt.DataType = trt.float32,
                          tp_size: int = 1,
                          tp_rank: int = 0,
                          alibi_scale: float = 1.0
) -> Tensor:
    '''
    Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.

    Parameters:
        num_heads : int
            The number of heads.
        dtype : trt.DataType
            The data type of the returned slopes
        tp_size : int
            The tensor parallelism size
        tp_rank : int
            The tensor parallelism rank

    Returns:
        A constant tensor that contains the ALiBi slopes.
    '''

    start_head_id = 0
    end_head_id = num_heads

    if tp_size > 1:
        rank_heads = num_heads // tp_size
        start_head_id = rank_heads * tp_rank
        end_head_id = start_head_id + rank_heads

    closest_power_of_2 = 2**np.floor(np.log2(num_heads))
    # FT's implementation
    # https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/gen_relative_pos_bias.cu#L248
    slopes_ft = []
    for h_id in range(start_head_id, end_head_id):
        if h_id < closest_power_of_2:
            slopes_ft.append(
                np.power(2**(-(2**-(np.log2(closest_power_of_2) - 3))),
                         h_id + 1))
        else:
            slopes_ft.append(
                np.power(2**(-(2**-(np.log2(closest_power_of_2 * 2) - 3))),
                         (h_id - closest_power_of_2) * 2 + 1))
    slopes = np.asarray(slopes_ft, dtype=np.float32)

    slopes = alibi_scale * slopes
    # Note that for bfloat16, we cannot cast a numpy tensor from float32 to bfloat16
    # because numpy does not support bfloat16. Even if we use a custom type to define
    # np_bfloat16, the "astype" here would be undefined.
    # So, we must use torch to cast the tensor from float32 to bfloat16, and then use
    # torch_to_numpy to cast it back.
    slopes = torch.from_numpy(slopes)
    slopes = slopes.to(trt_dtype_to_torch(dtype))
    slopes = torch_to_numpy(slopes)
    slopes = constant(slopes.reshape(1, (end_head_id - start_head_id), 1, 1))
    return slopes
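
As a worked example of the slope formula above (not from the source): for num_heads = 8, a power of two, the base is 2**(-(2**-(log2(8) - 3))) = 0.5, so the slopes form the geometric sequence 1/2, 1/4, ..., 1/256 described in the ALiBi paper.

import numpy as np

num_heads = 8
closest_power_of_2 = 2**np.floor(np.log2(num_heads))
base = 2**(-(2**-(np.log2(closest_power_of_2) - 3)))
slopes = [base**(h + 1) for h in range(num_heads)]
print(slopes)  # [0.5, 0.25, 0.125, ..., 0.00390625]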


def generate_alibi_biases(slopes: Tensor, key_length: Tensor) -> Tensor:
    '''
    Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.

    The ALiBi biases are added to the result of the Q*K^T product in the
    multihead-attention block.

    Parameters:
        slopes : Tensor
            The slopes.

        key_length : Tensor
            The size of the K vector per head.

    Returns:
        A constant tensor that contains the ALiBi biases.
    '''

    # We don't need to care about the batch size or query length since we can just broadcast

    # across the batch and query dimensions


    trt_0 = constant(int32_array(0))
    arange_shape = concat([1, 1, 1, key_length])

    arange_tensor = arange(trt_0, key_length, "float32").view(arange_shape)
    arange_tensor = cast(arange_tensor, "float32")
    return slopes * arange_tensor
def expand_mask(mask: Tensor, tgt_len: Optional[Tensor] = None) -> Tensor:
    '''
    Expand an attention mask.

    That function adds the sequence of operations to expand from a tensor of
    shape '[batch_size, src_seq_len]' to a tensor of shape
    '[batch_size, 1, tgt_seq_len, src_seq_len]'. It can be used to create the
    mask applied to the Q*K^T product before the softmax operation in the
    multihead-attention block.

    Parameters:
        mask : Tensor
            The input mask

        tgt_len : Optional[Tensor]
            The dimension of the 3rd dimension in the output tensor. If None,
            the 2nd dimension of the input is used.

    Returns:
        The tensor created by that sequence of operations.
    '''

    bsz = shape(mask, 0)
    src_len = shape(mask, 1)
    tgt_len = tgt_len if tgt_len is not None else src_len

    mask = mask.view(concat([bsz, 1, 1, src_len]))

    mask = expand(mask, concat([bsz, 1, tgt_len, src_len]))
    mask = where(mask == 0, float('-inf'), (1 - mask).cast('float32'))
    return mask
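
An illustrative PyTorch equivalent (hypothetical inputs) of what expand_mask builds: a padding mask of shape [batch_size, src_len] with 1 for real tokens and 0 for padding becomes a [batch_size, 1, tgt_len, src_len] additive bias with 0 at visible positions and -inf at masked ones.

import torch

mask = torch.tensor([[1, 1, 0]])                   # batch_size=1, src_len=3
tgt_len = 2
expanded = mask[:, None, None, :].expand(1, 1, tgt_len, 3).float()
bias = torch.where(expanded == 0, torch.tensor(float('-inf')), 1.0 - expanded)
print(bias)
# tensor([[[[0., 0., -inf],
#           [0., 0., -inf]]]])
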
def gather_last_token_logits(hidden_states: Tensor, last_token_ids: Tensor,
                             remove_input_padding: bool
) -> Tensor:
    '''
    Extract the logits that correspond to the last token from the hidden states.

    That function adds the operations to extract the logits of the last tokens
    in a batch of sequences.

    Depending on whether 'remove_input_padding' is 'True' or 'False', that
    function assumes inputs of different shapes.

    When 'remove_input_padding' is 'True', the 'hidden_states' tensor is
    assumed to be packed. It has a shape '[num_tokens, hidden_dim]' where
    'num_tokens' is the sum of the lengths of the sequences in the batch and
    'hidden_dim' is the hidden dimension. The 'last_tokens_ids' is a 1D tensor
    that encodes the inclusive prefix-sums of the lengths of the sequences in
    the batch.

    When 'remove_input_padding' is 'False', the 'hidden_states' tensor is
    assumed to be padded. It has a shape '[batch_size, max_seqlen, hidden_dim]'
    where 'max_seqlen' is the length of the longest sequence in the batch and
    'hidden_dim' is the hidden dimension.  The 'last_token_ids' is a 1D tensor
    that encodes the length of each sequence in the batch.

    In both cases, that function produces a tensor of shape '[batch_size,
    hidden_size]' where the row at index 'i' corresponds to the logits of the
    last token from the 'i'-th sequence.

    Parameters:
        hidden_states : Tensor
            The hidden states

        last_token_ids : Tensor
            The inclusive prefix-sum of the lengths or the lengths of the
            sequences in the batch.

        remove_input_padding : bool
            Indicate if the hidden_states are packed ('True') or padded
            ('False').

    Returns:
        The tensor created by that sequence of operations.
    '''

    if last_token_ids is None:
        return hidden_states

    if remove_input_padding:
        # hidden_states.shape = [1, num_tokens, hidden_dim]
        # last_token_ids.shape = (B,)
        hidden_states = index_select(hidden_states, 1,
                                     last_token_ids - 1)  # [1, seq_len, hidden]

        hidden_states = hidden_states.view(
            concat([shape(last_token_ids, 0),
                    shape(hidden_states, 2)]))  # [B, D]
    else:
        # only calculate logits for the last token
        # [batch_size, seqlen, hidden_size] -> [batch_size, hidden_size]
        last_token_ids = last_token_ids.view(
            concat([shape(last_token_ids, 0), 1, 1]))
        last_token_ids = expand(
            last_token_ids,
            concat([shape(last_token_ids, 0), 1,
                    shape(hidden_states, 2)]))
        last_token_ids = last_token_ids - 1
        hidden_states = gather(
            hidden_states, dim=1, indices=last_token_ids).view(
                concat([shape(hidden_states, 0),
                        shape(hidden_states, 2)]))
    return hidden_states
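
An illustrative PyTorch equivalent (hypothetical sizes) of the two code paths above:

import torch

hidden_dim = 4

# Packed mode: hidden_states is [1, num_tokens, hidden_dim]; last_token_ids holds the
# inclusive prefix-sums of the sequence lengths, here [3, 5] for lengths 3 and 2.
packed = torch.arange(5 * hidden_dim, dtype=torch.float32).view(1, 5, hidden_dim)
prefix_sums = torch.tensor([3, 5])
last_packed = packed[0, prefix_sums - 1]               # [batch_size, hidden_dim]

# Padded mode: hidden_states is [batch_size, max_seqlen, hidden_dim]; last_token_ids
# holds each sequence's length, here [3, 2].
padded = torch.arange(2 * 4 * hidden_dim, dtype=torch.float32).view(2, 4, hidden_dim)
lengths = torch.tensor([3, 2])
last_padded = padded[torch.arange(2), lengths - 1]     # [batch_size, hidden_dim]

print(last_packed.shape, last_padded.shape)            # torch.Size([2, 4]) twice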

References

  • https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/layers/attention.py