LLaMADecoderLayer
class LLaMADecoderLayer(Module):
    """One LLaMA decoder layer (pre-norm transformer block).

    ``forward`` computes::

        h = x + Attention(RmsNorm(x))
        y = h + GatedMLP(RmsNorm(h))
    """

    def __init__(self,
                 layer_id,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=2048,
                 dtype=None,
                 attention_mask_type=AttentionMaskType.causal,
                 hidden_act='silu',
                 position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
                 rotary_base=10000.0,
                 rotary_scaling=None,
                 mlp_hidden_size=None,
                 tp_group=None,
                 tp_size=1,
                 quant_mode=QuantMode(0),
                 rms_norm_eps=1e-06):
        """Create the norms, the attention module and the gated MLP.

        Args:
            layer_id: Index of this layer in the full model. Layer 0 registers
                extra debug outputs in ``forward``. The id also sets the
                ``instance_id`` of the attention module (``2 * layer_id``) and
                of the MLP (``2 * layer_id + 1``).
            hidden_size: Model hidden dimension.
            num_attention_heads: Number of query heads.
            num_kv_heads: Number of key/value heads, forwarded to ``Attention``.
            max_position_embeddings: Forwarded to ``Attention``.
            dtype: Data type used by all sub-modules.
            attention_mask_type: Mask type forwarded to ``Attention``
                (defaults to causal).
            hidden_act: Activation of the gated MLP.
            position_embedding_type, rotary_base, rotary_scaling: Rotary
                position-embedding settings forwarded to ``Attention``.
            mlp_hidden_size: Intermediate size of the gated MLP. A falsy value
                (``None``/0) selects ``4 * hidden_size``.
            tp_group, tp_size: Tensor-parallel group and size.
            quant_mode: Quantization mode. Its int8-KV-cache flag is also
                passed to ``Attention`` as ``use_int8_kv_cache``.
            rms_norm_eps: Epsilon of both RmsNorm layers.
        """
        super().__init__()
        self._layer_id = layer_id  # useful for debugging
        # Configuration kept on the instance; used for quantizing the model.
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_kv_heads = num_kv_heads
        self.max_position_embeddings = max_position_embeddings
        self.dtype = dtype
        self.hidden_act = hidden_act
        self.tp_group = tp_group
        self.tp_size = tp_size
        self.mlp_hidden_size = mlp_hidden_size or hidden_size * 4
        self.attention_mask_type = attention_mask_type
        self.position_embedding_type = position_embedding_type

        self.input_layernorm = RmsNorm(normalized_shape=hidden_size,
                                       eps=rms_norm_eps,
                                       dtype=dtype)
        self.attention = Attention(
            hidden_size,
            num_attention_heads,
            num_kv_heads,
            max_position_embeddings,
            dtype=dtype,
            # Honour the caller's mask type instead of always forcing causal.
            attention_mask_type=attention_mask_type,
            bias=False,
            position_embedding_type=position_embedding_type,
            rotary_embedding_base=rotary_base,
            rotary_embedding_scaling=rotary_scaling,
            tp_group=tp_group,
            tp_size=tp_size,
            use_int8_kv_cache=quant_mode.has_int8_kv_cache(),
            quant_mode=quant_mode,
            # Attention and MLP of one layer get distinct ids (even / odd).
            instance_id=2 * layer_id,
        )
        self.mlp = GatedMLP(hidden_size=hidden_size,
                            ffn_hidden_size=self.mlp_hidden_size,
                            hidden_act=hidden_act,
                            dtype=dtype,
                            bias=False,
                            tp_group=tp_group,
                            tp_size=tp_size,
                            quant_mode=quant_mode,
                            instance_id=2 * layer_id + 1)
        self.post_layernorm = RmsNorm(normalized_shape=hidden_size,
                                      eps=rms_norm_eps,
                                      dtype=dtype)

    def forward(self,
                hidden_states,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None,
                all_reduce_workspace=None):
        """Apply the decoder layer to ``hidden_states``.

        For layer 0 only, the intermediate values are registered as network
        outputs "norm0", "attn", "norm1" and "mlp" for debugging.

        Returns:
            The layer output. When ``use_cache`` is true, returns
            ``(output, presents)``, where ``presents`` is the second element
            returned by the attention module.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        if self._layer_id == 0:
            self.register_network_output("norm0", hidden_states)

        attention_output = self.attention(hidden_states,
                                          attention_mask=attention_mask,
                                          use_cache=use_cache,
                                          kv_cache_params=kv_cache_params,
                                          attention_params=attention_params,
                                          workspace=all_reduce_workspace)
        if use_cache:
            attention_output, presents = attention_output
        if self._layer_id == 0:
            self.register_network_output("attn", attention_output)

        # Residual connection around attention, then pre-norm for the MLP.
        hidden_states = residual + attention_output
        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)
        if self._layer_id == 0:
            self.register_network_output("norm1", hidden_states)

        hidden_states = self.mlp(hidden_states, all_reduce_workspace)
        if self._layer_id == 0:
            self.register_network_output("mlp", hidden_states)

        hidden_states = residual + hidden_states
        if use_cache:
            return (hidden_states, presents)
        return hidden_states
LLaMAModel
模型并行(张量并行 tp 与流水线并行 pp)在此实现。
class LLaMAModel(Module):
    """Stack of LLaMA decoder layers for one pipeline-parallel stage.

    - The first pipeline rank owns the token embedding. Later ranks ``recv``
      their input hidden states from the previous rank.
    - The last pipeline rank owns the final RmsNorm (``ln_f``). Earlier ranks
      ``send`` their output hidden states to the next rank.
    - Each rank builds only the layers returned by
      ``get_transformer_layers(mapping, num_layers)``. Tensor parallelism is
      applied inside each layer through ``mapping.tp_group``/``tp_size``.

    NOTE(review): ``get_transformer_layers`` is not defined on this class. It
    is expected to come from a mixin such as ``GenerationMixin``, as in
    ``LLaMAForCausalLM``.
    """

    def __init__(self,
                 num_layers,
                 num_heads,
                 num_kv_heads,
                 hidden_size,
                 vocab_size,
                 hidden_act,
                 max_position_embeddings,
                 dtype,
                 mlp_hidden_size=None,
                 position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
                 rotary_base=10000.0,
                 rotary_scaling=None,
                 mapping=Mapping(),
                 quant_mode=QuantMode(0),
                 use_parallel_embedding=False,
                 embedding_sharding_dim=0,
                 rms_norm_eps=1e-06):
        """Build the sub-modules owned by this pipeline rank.

        Args:
            num_layers: Total number of decoder layers across all pipeline
                stages.
            mapping: Parallel layout (tp/pp sizes, ranks and groups).
            use_parallel_embedding: When true, the embedding is split over the
                tensor-parallel group along ``embedding_sharding_dim``.
                Otherwise it is built with ``tp_size=1``.
            The remaining arguments are forwarded to the embedding, the
            decoder layers and the final norm.
        """
        super().__init__()
        self.mapping = mapping
        if self.mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(
                num_embeddings=vocab_size,
                embedding_dim=hidden_size,
                dtype=dtype,
                tp_size=mapping.tp_size if use_parallel_embedding else 1,
                tp_group=mapping.tp_group if use_parallel_embedding else None,
                sharding_dim=embedding_sharding_dim,
                tp_rank=mapping.tp_rank)

        # Only the layers of this pipeline stage are instantiated; layer_id
        # keeps the global index.
        self.layers = ModuleList([
            LLaMADecoderLayer(layer_id=i,
                              hidden_size=hidden_size,
                              num_attention_heads=num_heads,
                              num_kv_heads=num_kv_heads,
                              max_position_embeddings=max_position_embeddings,
                              dtype=dtype,
                              hidden_act=hidden_act,
                              mlp_hidden_size=mlp_hidden_size,
                              position_embedding_type=position_embedding_type,
                              rotary_base=rotary_base,
                              rotary_scaling=rotary_scaling,
                              tp_group=mapping.tp_group,
                              tp_size=mapping.tp_size,
                              quant_mode=quant_mode,
                              rms_norm_eps=rms_norm_eps)
            for i in self.get_transformer_layers(self.mapping, num_layers)
        ])

        if self.mapping.is_last_pp_rank():
            self.ln_f = RmsNorm(normalized_shape=hidden_size,
                                eps=rms_norm_eps,
                                dtype=dtype)

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                all_reduce_workspace=None):
        """Run this pipeline stage.

        Args:
            input_ids: Token ids; used only on the first pipeline rank.
            position_ids: Accepted for interface compatibility; not used by
                this method.
            hidden_states: Handle passed to ``recv`` on non-first pipeline
                ranks.
            kv_cache_params: Must not be ``None``; its attributes are read
                unconditionally. ``past_key_value`` and
                ``kv_cache_block_pointers`` hold one entry per local layer. A
                ``None`` list is treated as "no tensor" for every layer.

        Returns:
            The stage's hidden states (after ``ln_f`` on the last rank, after
            ``send`` otherwise). When ``use_cache`` is true, returns
            ``(hidden_states, tuple(presents))`` with one present per local
            layer.
        """
        # Substitute per-layer None placeholders so that zip() below still
        # pairs every layer with an entry.
        past_key_values = kv_cache_params.past_key_value
        if past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))
        block_pointers = kv_cache_params.kv_cache_block_pointers
        if block_pointers is None:
            block_pointers = tuple([None] * len(self.layers))

        if use_cache:
            presents = []

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
        self.register_network_output("embd", hidden_states)

        for layer, past, pointer in zip(self.layers, past_key_values,
                                        block_pointers):
            # Each layer sees only its own cache tensor / block pointers.
            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                attention_mask=attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    kv_cache_block_pointers=[pointer],
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                all_reduce_workspace=all_reduce_workspace)
            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states
GenerationMixin
class GenerationMixin:
    """Mixin that provides the pipeline layer split and declares the TensorRT
    network inputs (with dynamic-shape optimization-profile ranges) for
    generation models."""

    def get_transformer_layers(self, mapping, num_layers):
        """Return the global indices of the decoder layers owned by this
        pipeline rank.

        The ``num_layers`` layers are split into ``mapping.pp_size``
        contiguous stages of equal size. Rank ``mapping.pp_rank`` owns stage
        number ``pp_rank``.

        Raises:
            ValueError: if ``num_layers`` is not a multiple of
                ``mapping.pp_size``. An uneven split would otherwise silently
                leave the trailing layers without an owner.
        """
        if num_layers % mapping.pp_size != 0:
            raise ValueError(f"num_layers ({num_layers}) must be divisible by "
                             f"the pipeline-parallel size ({mapping.pp_size})")
        layers_per_pipeline_stage = num_layers // mapping.pp_size
        first_layer = mapping.pp_rank * layers_per_pipeline_stage
        return list(range(first_layer,
                          first_layer + layers_per_pipeline_stage))

    def prepare_basic_inputs(self,
                             max_batch_size,
                             max_beam_width,
                             max_input_len,
                             max_new_tokens,
                             num_kv_heads,
                             head_size,
                             num_layers,
                             kv_dtype,
                             remove_input_padding=False,
                             use_gpt_attention_plugin=False,
                             use_gemm_plugin=False,
                             use_custom_all_reduce=False,
                             paged_kv_cache=False,
                             tokens_per_block=64,
                             gather_all_token_logits=False,
                             dtype=None,
                             num_heads=None,
                             mapping=Mapping(),
                             max_num_tokens=None):
        """Declare the network input tensors and their dynamic-shape ranges.

        Every range is ``[min, opt, max]`` for one TensorRT optimization
        profile, and each ``dim_range`` entry holds one range per profile.

        Two profiles are used: profile 0 for the context phase and profile 1
        for the generation phase. There are two exceptions, where a single
        profile covers both phases:

        * both the GPT attention plugin and the GEMM plugin are enabled;
        * in-flight batching (GPT attention plugin + ``remove_input_padding``
          + ``paged_kv_cache``) is in use.

        Created inputs:

        * first pipeline rank: ``input_ids`` and ``position_ids``. Other
          ranks: ``hidden_states_input`` (requires ``dtype`` and
          ``num_heads``).
        * ``paged_kv_cache``: one ``kv_cache_block_pointers_{i}`` per local
          layer. Otherwise: one ``past_key_value_{i}`` per local layer.
        * GPT attention plugin: ``sequence_length``, ``host_request_types``,
          ``host_past_key_value_lengths`` and ``context_lengths``, plus
          ``host_context_lengths`` with ``remove_input_padding``. Otherwise:
          ``attention_mask``.
        * last pipeline rank without ``gather_all_token_logits``:
          ``last_token_ids``.
        * custom all-reduce with ``tp_size > 1``: ``all_reduce_workspace``.
        * ``cache_indirection`` is always created.

        Returns:
            dict from input name to ``Tensor``. The value is ``None`` for
            inputs that were not created; the two per-layer lists hold
            ``None`` for the KV-cache variant not in use.
        """
        max_len = max_input_len + max_new_tokens

        # Per-phase [min, opt, max] ranges.
        # cxt/ctx = context, gen = generation; bb = batch * beam,
        # bs = batch size.
        bb_range_cxt = [1, (max_batch_size + 1) // 2, max_batch_size]
        bb_range_gen = [
            1, (max_batch_size * max_beam_width + 1) // 2,
            max_batch_size * max_beam_width
        ]
        _bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
        _beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width]
        inlen_range_cxt = [1, (max_input_len + 1) // 2, max_input_len]
        inlen_range_gen = [1, 1, 1]
        _mask_len_ctx = [1, (max_input_len + 1) // 2, max_input_len]
        _mask_len_gen = [2, (max_len + 1) // 2 + 1, max_len + 1]
        _kv_cache_range_ctx = [0, 0, 0]
        _kv_cache_range_gen = [1, (max_len + 1) // 2, max_len]
        _max_len_range = [0, (max_len + 1) // 2, max_len]

        # Range of the packed token dimension used with remove_input_padding.
        # Both branches yield one flat [min, opt, max] per phase, so the
        # per-profile lists below wrap each of them exactly once.
        if max_num_tokens is None:
            # Context: every prompt token. Generation: one token per sequence.
            num_tokens_range_ctx = [
                1, (max_input_len * max_batch_size + 1) // 2,
                max_input_len * max_batch_size
            ]
            num_tokens_range_gen = [
                1, max_batch_size * max_beam_width,
                max_beam_width * max_batch_size
            ]
        else:
            num_tokens_range_ctx = [
                1, (max_num_tokens + 1) // 2, max_num_tokens
            ]
            num_tokens_range_gen = [
                1, (max_num_tokens + 1) // 2, max_num_tokens
            ]

        enable_two_optimization_profiles = False
        if not use_gpt_attention_plugin or not use_gemm_plugin:
            # In-flight batching is enabled when the GPT attention plugin,
            # remove_input_padding and paged_kv_cache are all on.
            use_in_flight_batching = (use_gpt_attention_plugin
                                      and remove_input_padding
                                      and paged_kv_cache)
            enable_two_optimization_profiles = not use_in_flight_batching
        num_profiles = 2 if enable_two_optimization_profiles else 1

        def _fixed(value):
            # Range list for a static dimension: the same value per profile.
            return [value] * num_profiles

        if enable_two_optimization_profiles:
            # Profile 0 = context phase, profile 1 = generation phase.
            bb_range = [bb_range_cxt, bb_range_gen]
            bs_range = [_bs_range, _bs_range]
            beam_width_range = [_beam_width_range, _beam_width_range]
            inlen_range = [inlen_range_cxt, inlen_range_gen]
            mask_len_range = [_mask_len_ctx, _mask_len_gen]
            if use_gpt_attention_plugin:
                kv_cache_range = [_kv_cache_range_gen, _kv_cache_range_gen]
            else:
                kv_cache_range = [_kv_cache_range_ctx, _kv_cache_range_gen]
            max_len_range = [_max_len_range, _max_len_range]
            num_tokens_range = [num_tokens_range_ctx, num_tokens_range_gen]
        else:
            # A single profile spanning both phases.
            bb_range = [bb_range_gen]
            bs_range = [_bs_range]
            beam_width_range = [_beam_width_range]
            inlen_range = [[1, 1, max_input_len]]
            mask_len_range = [[1, (max_len + 1) // 2 + 1, max_len + 1]]
            kv_cache_range = [[0, (max_len + 1) // 2, max_len]]
            max_len_range = [_max_len_range]
            if max_num_tokens is None:
                num_tokens_range = [[
                    1, max_batch_size * max_beam_width,
                    max(max_input_len * max_batch_size,
                        max_beam_width * max_batch_size)
                ]]
            else:
                num_tokens_range = [num_tokens_range_ctx]

        # Token / hidden-state inputs.
        input_ids = None
        position_ids = None
        hidden_states = None
        hidden_size = head_size * num_heads if num_heads is not None else None
        if remove_input_padding:
            # Sequences are packed into a single row: shape [1, num_tokens].
            if mapping.is_first_pp_rank():
                input_ids = Tensor(name='input_ids',
                                   dtype=trt.int32,
                                   shape=[1, -1],
                                   dim_range=OrderedDict([
                                       ('batch_size_fake', _fixed(1)),
                                       ('num_tokens', num_tokens_range),
                                   ]))
                position_ids = Tensor(name='position_ids',
                                      dtype=trt.int32,
                                      shape=[1, -1],
                                      dim_range=OrderedDict([
                                          ('batch_size_fake', _fixed(1)),
                                          ('num_tokens', num_tokens_range),
                                      ]))
            else:
                # Later pipeline ranks receive hidden states instead of ids.
                assert dtype is not None
                assert num_heads is not None
                hidden_states = Tensor(name='hidden_states_input',
                                       dtype=dtype,
                                       shape=[1, -1, hidden_size],
                                       dim_range=OrderedDict([
                                           ('batch_size_fake', _fixed(1)),
                                           ('num_tokens', num_tokens_range),
                                           ('hidden_size',
                                            _fixed(hidden_size)),
                                       ]))
        else:
            # Padded layout: [batch_size * beam_width, sequence_length].
            if mapping.is_first_pp_rank():
                input_ids = Tensor(name='input_ids',
                                   dtype=trt.int32,
                                   shape=[-1, -1],
                                   dim_range=OrderedDict([
                                       ('batch_size_beam_width', bb_range),
                                       ('input_len', inlen_range),
                                   ]))
                position_ids = Tensor(name='position_ids',
                                      dtype=trt.int32,
                                      shape=[-1, -1],
                                      dim_range=OrderedDict([
                                          ('batch_size_beam_width', bb_range),
                                          ('input_len', inlen_range),
                                      ]))
            else:
                assert dtype is not None
                assert num_heads is not None
                hidden_states = Tensor(name='hidden_states_input',
                                       dtype=dtype,
                                       shape=[-1, -1, hidden_size],
                                       dim_range=OrderedDict([
                                           ('batch_size_beam_width', bb_range),
                                           ('input_len', inlen_range),
                                           ('hidden_size',
                                            _fixed(hidden_size)),
                                       ]))

        # KV heads held by each tensor-parallel rank (ceiling division).
        num_kv_heads = (num_kv_heads + mapping.tp_size - 1) // mapping.tp_size
        # Global indices of the layers owned by this pipeline rank.
        layers_range = self.get_transformer_layers(mapping, num_layers)
        past_key_value = []  # one item per layer; tensors when !paged_kv_cache
        kv_cache_block_pointers_list = []  # one item per layer; tensors when paged_kv_cache
        if not paged_kv_cache:
            # Contiguous ("linear") KV cache.
            for i in layers_range:
                kv_dim_range = OrderedDict([
                    ('batch_size_beam_width', bb_range),
                    ('kv', _fixed(2)),
                    ('num_heads', _fixed(num_kv_heads)),
                    ('past_key_len', kv_cache_range),
                    ('head_size', _fixed(head_size)),
                ])
                kv = Tensor(name=f'past_key_value_{i}',
                            dtype=kv_dtype,
                            shape=[-1, 2, num_kv_heads, -1, head_size],
                            dim_range=kv_dim_range)
                past_key_value.append(kv)
                kv_cache_block_pointers_list.append(None)
        else:
            # Paged KV cache: per-profile range of blocks per sequence,
            # derived from the KV-length range.
            max_blocks_per_seq_range = [[
                math.ceil(length / tokens_per_block) for length in profile
            ] for profile in kv_cache_range]
            for i in layers_range:
                kv_cache_block_pointers = Tensor(
                    name=f'kv_cache_block_pointers_{i}',
                    dtype=trt.int64,
                    shape=[-1, 2, -1],
                    dim_range=OrderedDict([
                        ('batch_size_beam_width', bb_range),
                        ('kv', _fixed(2)),
                        ('max_blocks_per_seq', max_blocks_per_seq_range),
                    ]))
                kv_cache_block_pointers_list.append(kv_cache_block_pointers)
                past_key_value.append(None)

        sequence_length = None
        context_lengths = None
        host_context_lengths = None
        host_past_key_value_lengths = None
        attention_mask = None
        cache_indirection = None
        host_request_types = None

        if use_gpt_attention_plugin:
            sequence_length = Tensor(
                name='sequence_length',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )
            host_request_types = Tensor(
                name='host_request_types',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )
            host_past_key_value_lengths = Tensor(
                name='host_past_key_value_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )
            context_lengths = Tensor(
                name='context_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )
        else:
            attention_mask = Tensor(
                name='attention_mask',
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([
                    ('batch_size_beam_width', bb_range),
                    ('mask_len', mask_len_range),
                ]),
            )

        if use_gpt_attention_plugin and remove_input_padding:
            host_context_lengths = Tensor(
                name='host_context_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )

        last_token_ids = None
        if mapping.is_last_pp_rank() and not gather_all_token_logits:
            # Needed when only the last token's logits are gathered.
            last_token_ids = Tensor(
                name='last_token_ids',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([
                    ('batch_size_last_token_ids', bb_range),
                ]),
            )

        # Beam-search backtrace map.
        cache_indirection = Tensor(
            name='cache_indirection',
            dtype=trt.int32,
            shape=[-1, -1, -1],
            dim_range=OrderedDict([
                ('batch_size_cache', bs_range),
                ('beam_width', beam_width_range),
                ('max_seq_len', max_len_range),
            ]),
        )

        all_reduce_workspace = None
        if use_custom_all_reduce and mapping.tp_size > 1:
            # 3 (= buffer + signals_in + signals_out) per tensor-parallel rank.
            workspace_size = 3 * mapping.tp_size
            all_reduce_workspace = Tensor(
                name='all_reduce_workspace',
                dtype=trt.int64,
                shape=[workspace_size],
                dim_range=OrderedDict([
                    ('all_reduce_size', _fixed(workspace_size)),
                ]))

        return {
            # first pipeline rank only
            'input_ids': input_ids,
            'position_ids': position_ids,
            # pipeline ranks > 0 only
            'hidden_states_input': hidden_states,
            # without the GPT attention plugin
            'attention_mask': attention_mask,
            # per layer; tensors only without paged_kv_cache
            'past_key_value': past_key_value,
            # per layer; tensors only with paged_kv_cache
            'kv_cache_block_pointers_list': kv_cache_block_pointers_list,
            # last pipeline rank without gather_all_token_logits
            'last_token_ids': last_token_ids,
            # with the GPT attention plugin
            'sequence_length': sequence_length,
            'host_request_types': host_request_types,
            'host_past_key_value_lengths': host_past_key_value_lengths,
            'context_lengths': context_lengths,
            # with the GPT attention plugin and remove_input_padding
            'host_context_lengths': host_context_lengths,
            # always created
            'cache_indirection': cache_indirection,
            # with use_custom_all_reduce and mapping.tp_size > 1
            'all_reduce_workspace': all_reduce_workspace,
        }
LLaMAForCausalLM
class LLaMAForCausalLM(LLaMAModel, GenerationMixin):
    """LLaMA decoder stack plus the language-model head.

    On the last pipeline rank, ``lm_head`` projects hidden states to logits
    over the padded vocabulary. Other ranks output their hidden states as
    ``hidden_states_output`` for the next stage.
    """

    def __init__(self,
                 num_layers,
                 num_heads,
                 num_kv_heads,
                 hidden_size,
                 vocab_size,
                 hidden_act,
                 max_position_embeddings,
                 dtype,
                 logits_dtype="float32",
                 mlp_hidden_size=None,
                 position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
                 rotary_base=10000.0,
                 rotary_scaling=None,
                 mapping=Mapping(),
                 quant_mode=QuantMode(0),
                 use_parallel_embedding=False,
                 embedding_sharding_dim=0,
                 rms_norm_eps=1e-06):
        """Record the model configuration and build the sub-modules.

        Args:
            dtype: Model dtype, either a string accepted by
                ``str_dtype_to_trt`` or a ``trt.DataType``.
            logits_dtype: dtype of the ``logits`` output; accepts the same
                forms as ``dtype``.
            num_kv_heads: ``None`` or a value ``<= 0`` means "same as
                ``num_heads``".
            quant_mode: Its int8 / fp8 KV-cache flags select ``kv_dtype``.
            The remaining arguments are forwarded to ``LLaMAModel``.
        """
        if isinstance(dtype, str):
            self.dtype = str_dtype_to_trt(dtype)
        else:
            assert isinstance(dtype, trt.DataType)
            self.dtype = dtype
        if isinstance(logits_dtype, str):
            self.logits_dtype = str_dtype_to_trt(logits_dtype)
        else:
            assert isinstance(logits_dtype, trt.DataType)
            self.logits_dtype = logits_dtype
        self.num_layers = num_layers
        self.num_heads = num_heads
        if num_kv_heads is None or num_kv_heads <= 0:
            num_kv_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.tp_size = mapping.tp_size
        # KV-cache dtype: the model dtype unless an int8/fp8 KV cache is
        # requested by the quantization mode.
        self.kv_dtype = self.dtype
        if quant_mode.has_int8_kv_cache():
            self.kv_dtype = str_dtype_to_trt('int8')
        elif quant_mode.has_fp8_kv_cache():
            self.kv_dtype = str_dtype_to_trt('fp8')
        self.quant_mode = quant_mode
        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim

        super().__init__(num_layers=num_layers,
                         num_heads=num_heads,
                         num_kv_heads=num_kv_heads,
                         hidden_size=hidden_size,
                         vocab_size=vocab_size,
                         hidden_act=hidden_act,
                         max_position_embeddings=max_position_embeddings,
                         dtype=dtype,
                         mlp_hidden_size=mlp_hidden_size,
                         position_embedding_type=position_embedding_type,
                         rotary_base=rotary_base,
                         rotary_scaling=rotary_scaling,
                         mapping=mapping,
                         quant_mode=quant_mode,
                         use_parallel_embedding=use_parallel_embedding,
                         embedding_sharding_dim=embedding_sharding_dim,
                         rms_norm_eps=rms_norm_eps)

        vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
        if self.mapping.is_last_pp_rank():
            self.lm_head = ColumnLinear(hidden_size,
                                        vocab_size_padded,
                                        bias=False,
                                        dtype=dtype,
                                        tp_group=mapping.tp_group,
                                        tp_size=mapping.tp_size,
                                        gather_output=True)

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                all_reduce_workspace=None):
        """Run the decoder stack and mark this rank's network outputs.

        Returns:
            On the last pipeline rank, the logits (marked as ``logits``).
            On other ranks, the hidden states (marked as
            ``hidden_states_output``).

            When ``use_cache`` is true and the paged KV cache is disabled,
            returns ``(output, presents)`` instead, and each present is
            marked as ``present_key_value_{i}`` using the global layer index
            ``i``. With the paged KV cache, no presents are returned.
        """
        hidden_states = super().forward(
            input_ids=input_ids,
            position_ids=position_ids,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            hidden_states=hidden_states,
            all_reduce_workspace=all_reduce_workspace)
        if use_cache:
            hidden_states, presents = hidden_states

        if self.mapping.is_last_pp_rank():
            # Reduce to the positions selected by last_token_ids before the
            # projection (exact semantics are those of
            # gather_last_token_logits; presumably one row per sequence).
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)
            # [batch_size, hidden_size] -> [batch_size, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            lm_logits.mark_output('logits', self.logits_dtype)
            output = lm_logits
        else:
            hidden_states.mark_output('hidden_states_output', self.dtype)
            output = hidden_states

        if use_cache and not default_net().plugin_config.paged_kv_cache:
            # Only the contiguous KV cache exposes per-layer presents as
            # outputs.
            for i, present in zip(
                    self.get_transformer_layers(self.mapping,
                                                self.num_layers), presents):
                present.mark_output(f'present_key_value_{i}', self.kv_dtype)
            return (output, presents)
        return output

    def prepare_inputs(self,
                       max_batch_size,
                       max_input_len,
                       max_new_tokens,
                       use_cache,
                       max_beam_width,
                       max_num_tokens: int = None):
        '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
            ranges of the dimensions of when using TRT dynamic shapes.

            The plugin settings (remove_input_padding, GPT attention / GEMM
            plugins, paged KV cache, tokens_per_block, custom all-reduce) are
            read from ``default_net().plugin_config``.

            NOTE: the ``use_cache`` argument is not consulted; the third
            returned element is always ``True``.

            @return: a tuple ordered like the positional parameters of
            ``self.forward()``: (input_ids, position_ids, use_cache,
            last_token_ids, attention_mask, kv_cache_params,
            attention_params, hidden_states, all_reduce_workspace).
        '''
        head_size = self.hidden_size // self.num_heads
        plugin_config = default_net().plugin_config
        remove_input_padding = plugin_config.remove_input_padding
        use_gpt_attention_plugin = plugin_config.gpt_attention_plugin
        use_gemm_plugin = plugin_config.gemm_plugin
        paged_kv_cache = plugin_config.paged_kv_cache
        tokens_per_block = plugin_config.tokens_per_block
        use_custom_all_reduce = plugin_config.use_custom_all_reduce

        model_inputs = self.prepare_basic_inputs(
            max_batch_size,
            max_beam_width,
            max_input_len,
            max_new_tokens,
            self.num_kv_heads,
            head_size,
            self.num_layers,
            self.kv_dtype,
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            use_custom_all_reduce=use_custom_all_reduce,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            dtype=self.dtype,
            num_heads=self.num_heads,
            mapping=self.mapping,
            max_num_tokens=max_num_tokens)

        return (model_inputs['input_ids'], model_inputs['position_ids'], True,
                model_inputs['last_token_ids'], model_inputs['attention_mask'],
                KeyValueCacheParams(
                    past_key_value=model_inputs['past_key_value'],
                    host_past_key_value_lengths=model_inputs[
                        'host_past_key_value_lengths'],
                    kv_cache_block_pointers=model_inputs[
                        'kv_cache_block_pointers_list'],
                    cache_indirection=model_inputs['cache_indirection'],
                ),
                AttentionParams(
                    sequence_length=model_inputs['sequence_length'],
                    context_lengths=model_inputs['context_lengths'],
                    host_context_lengths=model_inputs['host_context_lengths'],
                    max_context_length=max_input_len,
                    host_request_types=model_inputs['host_request_types']),
                model_inputs['hidden_states_input'],
                model_inputs['all_reduce_workspace'])
参考文献
• https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/models/llama/model.py

夜雨聆风