乐于分享
好东西不私藏

TensorRT-LLM 0.5.0 源码之六

TensorRT-LLM 0.5.0 源码之六

network.py

_UniqueNameGenerator

# name1 = generator('UserService', 'com.moduleA')  # 返回 'com/moduleA/UserService_0'
# name2 = generator('UserService', 'com.moduleB')  # 返回 'com/moduleB/UserService_0'


# file_generator = _UniqueNameGenerator('file_')

# unique_filename = file_generator('image', 'uploads')  # 返回 'file_uploads/image_0'


# session_generator = _UniqueNameGenerator('session_')

# session_id = session_generator('user', 'webapp')  # 返回 'session_webapp/user_0'


class
 _UniqueNameGenerator(object):

    def
 __init__(self, prefix=''):
        self
.ids = collections.defaultdict(int)
        self
.prefix = prefix

    def
 __call__(self, key, module_name=''):
        if
 module_name != '':
            module_name = module_name.replace(".", "/")
            key = module_name + '/' + key
        tmp = self.ids[key]
        self
.ids[key] += 1
        return
 f"{self.prefix}{key}_{tmp}"

net_guard

@contextlib.contextmanager
def
 net_guard(network):
    from
 ._common import net
    assert
 isinstance(
        network, Network
    ), f"Invalid network, can only guard Network instance, got: {network}"

    old_net = net
    set_network(network)
    yield

    set_network(old_net)

_TrtLlmModuleCallStack

_TrtLlmModuleCallStack 类是一个用于在模型执行过程中动态追踪和记录模块调用栈的工具。它在深度学习框架(尤其是像TensorRT-LLM这样复杂的推理引擎)的调试、性能分析或内部状态监控中非常有用。

这个类的主要目标是提供一个轻量级的机制,来回答“当前代码正在哪个模块中执行?”这个问题。它通过维护一个运行时的调用栈来实现这一点,每当进入一个模块时,将其名称压入栈中,离开时弹出。


class
 _TrtLlmModuleCallStack(object):
    # 这是类的核心,作为一个栈数据结构来使用。它动态地记录着当前执行路径上经过的模块名称序列。栈顶元素(call_stack[-1])就代表了当前正在执行的模块。

    call_stack = []
    # 这个字典充当一个模块名称的注册表。它的键是模块对象本身,值是该对象的完整名称(例如,"model.transformer.layers.0.attention")。

    module_name_map = {}

    def
 __init__(self):
        super
().__init__()
        self
.mod_names_set = False

    def
 module_names_set(self):
        return
 self.mod_names_set

    def
 set_module_names(self, top_level_module):
        assert
 top_level_module, "Expected a top level module"
        for
 name, mod in top_level_module.named_modules(
                prefix=top_level_module._get_name()): # 遍历所有子模块
            if
 mod not in self.module_name_map:
                self
.module_name_map[mod] = name # 注册模块对象到名称的映射
        self
.mod_names_set = True
        return


    def
 get_current_module(self):
        mod_name = ''
        if
 len(self.call_stack):
            mod_name = self.call_stack[-1]
        return
 mod_name

    def
 get_mod_name(self, mod_obj):
        name = ''
        if
 mod_obj in self.module_name_map:
            name = self.module_name_map[mod_obj]
        return
 name

    def
 get_stack(self):
        return
 self.call_stack

    @contextlib.contextmanager

    def
 call_stack_mgr(self):
        call_stack = self.get_stack()
        try
:
            yield
 call_stack # 在此处执行带有模块名的压栈操作
        finally
:
            call_stack.pop() # 无论块内代码是否异常,最终都会执行弹栈

Network

class Network(object):

    def
 __init__(self, **kwargs):
        # intentionally use **kwargs, user should never call this ctor directly

        # use Builder.create_network() instead


        # Holds the removed layers and disable them in graph rewritings and other phases.

        # This is a hacky way since INetwork python API doesn't provide a way to remove a layer.

        # TODO: remove this when TensorRT provides a better way to remove a layer

        self
._removed_layers: Set[str] = set()

        self
.is_graph_altered = False

        from
 .graph_rewriting import FLayerInfoMemo
        self
.flayer_memo = FLayerInfoMemo()  # holds the functional metadata
    def _init(self, trt_network):
        self
._trt_network = trt_network
        self
._inputs = {}
        self
._named_parameters = None
        # layer precision of a given scope, this is used together with precision(dtype) context manager

        self
._dtype = None
        self
._name_generator = _UniqueNameGenerator()
        self
._plugin_config = PluginConfig()
        self
._module_call_stack = _TrtLlmModuleCallStack()
        self
._registered_ndarrays = []
        self
._strongly_typed = trt.INetworkDefinition.get_flag(
            self
._trt_network, trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)

        return
 self

    @property

    def
 dtype(self) -> trt.DataType:
        return
 self._dtype

    @dtype.setter

    def
 dtype(self, dtype: trt.DataType):
        assert
 isinstance(dtype, trt.DataType) or dtype is None
        self
._dtype = dtype

    @property

    def
 trt_network(self) -> trt.INetworkDefinition:
        return
 self._trt_network

    @property

    def
 plugin_config(self) -> PluginConfig:
        return
 self._plugin_config

    @property

    def
 strongly_typed(self) -> bool:
        return
 self._strongly_typed
    def _add_input(self,
                   tensor,
                   name,
                   dtype,
                   shape,
                   dim_range: OrderedDict = None
):
        assert
 isinstance(dtype, trt.DataType)
        tensor.trt_tensor = self.trt_network.add_input(
            name=name,
            shape=shape,
            dtype=dtype,
        )
        if
 dim_range is not None:
            logger.debug(
                f'Add input: {name}, shape: {shape}, dtype: {dtype}, dimension names:{list(dim_range.keys())}'

            )
            for
 i, dim_name in enumerate(dim_range.keys()):
                tensor.trt_tensor.set_dimension_name(i, str(dim_name))
        else
:
            logger.debug(f'Add input: {name}, shape: {shape}, dtype: {dtype}')
        self
._inputs[name] = tensor
    def _mark_output(self, tensor, name, dtype):
        from
 .functional import cast

        if
 self.strongly_typed:
            if
 tensor.trt_tensor.dtype != dtype:
                # If stronglyTyped mode is enabled and inferred output dtype does not match desired dtype, add a cast.

                cast_output = cast(tensor, dtype)
                self
.trt_network.mark_output(cast_output.trt_tensor)
                cast_output.trt_tensor.name = name
            else
:
                # Otherwise, mark the tensor as network output. We should not set tensor dtype in stronglyTyped mode.

                self
.trt_network.mark_output(tensor.trt_tensor)
                tensor.trt_tensor.name = name
        else
:
            self
.trt_network.mark_output(tensor.trt_tensor)
            tensor.trt_tensor.name = name
            tensor.trt_tensor.dtype = dtype
        logger.debug(f'Mark output: {name}, dtype: {dtype}')
    def set_named_parameters(self, named_parameters):
        self
._named_parameters = named_parameters

    @property

    def
 named_parameters(self):
        return
 self._named_parameters
    def _set_layer_name(self, layer):
        layer_name = str(layer.type).split('.')[-1]
        current_module = self._module_call_stack.get_current_module()

        if
 layer.type == trt.LayerType.PLUGIN_V2:
            layer_name = '_'.join(
                [layer_name,
                 str
(layer.plugin.plugin_type).split('.')[-1]])
        elif
 layer.type in [
                trt.LayerType.UNARY, trt.LayerType.REDUCE,
                trt.LayerType.ELEMENTWISE
        ]:
            layer_name = '_'.join([layer_name, str(layer.op).split('.')[-1]])

        layer.name = self._name_generator(layer_name, current_module)
        for
 idx in range(layer.num_outputs):
            # TRT initializes tensor names from the initial layer's name when the layer is created,

            # and does not update tensor names when layer name changed by application, needs to

            # change the tensor name to align with the new layer name for better debugging

            layer.get_output(idx).name = f"{layer.name}_output_{idx}"
    def register_ndarray(self, ndarray: np.ndarray) -> None:
        self
._registered_ndarrays.append(ndarray)
    def get_inputs(self):
        '''
        Get the inputs of the network.

        Returns:
            Iterable[Tensor]
        '''

        return
 self._inputs.values()

    def
 get_outputs(self):
        '''
        Get the outputs of the network.

        Returns:
            Iterable[Tensor]
        '''

        from
 .functional import Tensor
        for
 i in range(self._trt_network.num_outputs):
            tensor = self._trt_network.get_output(i)
            yield
 Tensor(trt_tensor=tensor,
                         network=self,
                         is_network_input=False)
    def is_input(self, tensor) -> bool:
        '''
        Tell if a tensor is a input of the network.

        Parameters:
            tensor: Union[Tensor, str, trt.ITensor]
        '''

        from
 .functional import Tensor

        if
 isinstance(tensor, str):
            tensor_name = tensor
        elif
 isinstance(tensor, (trt.ITensor, Tensor)):
            tensor_name = tensor.name
        else
:
            raise
 ValueError(
                f"tensor should be Tensor, str or ITensor, got {tensor}"
)

        return
 self._inputs.get(tensor_name, False)

    def
 is_output(self, tensor) -> bool:
        '''
        Tell if a tensor is a output of the network.

        Parameters:
            tensor: Tensor
        '''

        for
 i in range(self._trt_network.num_outputs):
            if
 tensor.trt_tensor is self._trt_network.get_output(i):
                return
 True
        return
 False
    def get_layers(self) -> Iterable["Layer"]:
        '''
        Get all the layers of network.

        Returns:
            Iterable[Layer]
        '''

        from
 .graph_rewriting import Layer
        for
 i in range(self._trt_network.num_layers):
            layer = Layer(network=self,
                          trt_layer=self._trt_network.get_layer(i))
            yield
 layer

    def
 get_layer_by_name(self, name: str) -> Optional["Layer"]:
        state = self._get_graph()
        return
 state.name_to_layer.get(name, None)
    def get_tensor_users(self, tensor) -> Iterable["Layer"]:
        '''
        Get the layers those consumes this tensor.
        '''

        state = self._get_graph()
        for
 layer in state.tensor_to_consumers[tensor]:
            yield
 layer

    def
 get_tensor_parent(self, tensor) -> Optional["Layer"]:
        '''
        Get the layer that produces this tensor.
        '''

        state = self._get_graph()
        return
 state.tensor_to_producer.get(tensor, None)
    def mark_removed_layer(self, layer: "Layer"):
        from
 .graph_rewriting import FLayerInfoMemo
        self
._removed_layers.add(layer.name)

        # Try to delete the layer if it is a Plugin

        FLayerInfoMemo.instance().remove(layer.name)

    def
 is_removed_layer(self, layer: "Layer") -> bool:
        return
 layer.name in self._removed_layers

    @property

    def
 removed_layers(self) -> Iterable["Layer"]:
        for
 layer_name in self._removed_layers:
            layer = self.get_layer_by_name(layer_name)
            assert
 layer, "Invalid layer name"
            yield
 layer
    def _get_graph(self) -> "Network._GraphState":
        '''
        Get the graph of the network.

        Returns:
            Network._GraphState
        '''

        return
 self._get_graph_impl(self._get_network_hash())

    @lru_cache(maxsize=1)

    def
 _get_graph_impl(self, network_hash: bytes) -> "Network._GraphState":
        graph = Network._GraphState()
        graph.build(self)
        return
 graph

    @dataclass

    class
 _GraphState:
        # Tensor to Layers

        tensor_to_consumers: Dict[Any, List["Layer"]] = field(
            default_factory=lambda: defaultdict(list))
        # Tensor to Layer

        tensor_to_producer: Dict[Any, "Layer"] = field(default_factory=dict)
        inputs: Dict[str, Any] = field(default_factory=OrderedDict)
        outputs: Dict[str, Any] = field(default_factory=OrderedDict)
        name_to_layer: Dict[str, "Layer"] = field(default_factory=dict)

        def
 build(self, network: "Network") -> None:
            from
 .graph_rewriting import Layer
            self
.inputs = network.get_inputs()
            self
.outputs = network.get_outputs()

            for
 layer in network.get_layers():
                self
.name_to_layer[layer.name] = Layer(
                    network=network, trt_layer=layer.trt_layer)
                for
 i in range(layer.num_inputs):
                    input_tensor = layer.get_inputs(i)[0]
                    if
 input_tensor.is_trt_wrapper():
                        self
.tensor_to_consumers[input_tensor].append(layer)
                for
 i in range(layer.num_outputs):
                    output_tensor = layer.get_outputs(i)[0]
                    if
 output_tensor.is_trt_wrapper():
                        self
.tensor_to_producer[output_tensor] = layer

    def
 _get_network_hash(self, lightweight=True) -> bytes:
        # TODO: Ask TensorRT team to add a hash function for INetworkDefinition instead of using this hacky way

        num_layers = self.trt_network.num_layers

        # Some special layers, such as slice, may be associated with tensors that do not have the `trt_tensor` member.

        get_tensor_tag = lambda tensor: tensor.trt_tensor.name if tensor.is_trt_wrapper(
        ) else 'None'

        if
 lightweight and not self.is_graph_altered:
            return
 num_layers
        self
.is_graph_altered = False

        data = hashlib.sha256()
        # network layer count

        data.update(str(num_layers).encode())
        # network inputs

        data.update(','.join(
            [get_tensor_tag(tensor) for tensor in self.get_inputs()]).encode())
        # network outputs

        data.update(','.join(
            [get_tensor_tag(tensor) for tensor in self.get_outputs()]).encode())
        # layer names

        data.update(','.join(
            [layer.trt_layer.name for layer in self.get_layers()]).encode())

        # layer -> output

        data.update(','.join([
            f'{layer.trt_layer.name}->{get_tensor_tag(tensor)}'

            for
 layer in self.get_layers() for tensor in layer.get_outputs()
        ]).encode())

        # input -> layer

        data.update(','.join([
            f'{get_tensor_tag(tensor)}->{layer.trt_layer.name}'

            for
 layer in self.get_layers() for tensor in layer.get_inputs()
        ]).encode())

        return
 data.hexdigest()

参考文献

  • • https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/network.py
点个「赞」+「在看」❤️
让我们知道这份文字有温暖到你,也是我们持续创作的最大动力!
推荐
Share Memory 的 Bank Conflict
告别高成本!TensorRT-LLM实战:如何将LLM推理速度提升数倍
使用LoRA对LLM进行微调的实用技巧
强化学习小白必看:PTX Loss 到底是个啥?
GPT-5 Prompt Migration and Improvement Using the New Optimizer
Task 异步流 coroutine 实现
C++ corotine 介绍
搭建 VSCode 离线开发环境
nlohmann/json 库简介
Intro to C++ Coroutines: Concept
Hugging Face BPE Tokenizer 的资源文件
移动语义 std::move 和完美转发 std::forward
ACEBench: Who Wins the Match Point in Tool Usage?
什么是 GN
RULER: Relative Universal LLM-Elicited Rewards
SFT和RFT的区别
CosyVoice 3: 面向真实场景的大规模零样本语音生成模型
CosyVoice 3: Towards In-the-wild Speech Generation
语音合成(TTS)中文自然度:问题、成因、解决方案
上下文工程如何实现
上下文工程(Context Engineering)
新手必看!LangGraph 101:手把手教你搭一个深度研究 Agent
LangGraph 简介
SFT 泛化新解读:强化学习 + 奖励修正,一文读懂
程序员狂喜!Self-Instruct 框架全解析:无限生成高质量指令集,从此告别标注噩梦!
Evol-Instruct 竟能精准生成领域专属数据?实操技巧速看!
指令微调数据-少即是多
LLM generate 参数怎么用?
语音合成(TTS)跳跃与重复问题的解析:成因、机制及解决方案
大模型训练新思路:GEPA 靠 “反思” 赢过 RL,看完秒懂
F5-TTS:用 Flow Matching 玩转语音,流畅度和真实感都 “拉满” 了
E2 TTS:令人尴尬地简单、完全非自回归、零样本的语音合成技术
Voicebox:大规模文本引导的多语言通用语音生成技术
为什么都在聊 Kimi K2?Open Agentic Intelligence 藏着哪些新惊喜
Step-Audio-AQAA 端到端音频模型
DPO、PPO、GRPO的原理,区别与联系
OPENCSG 中文语料库:一系列高质量的中文数据集,用于语言模型训练
什么是 Classifier-Free Guidance?
Conditional Flow Matching : 连续标准流 Continuous Normalizing Flow
CFM 与 OT-CFM:条件流匹配与最优传输的碰撞
DPO损失实现
Conditional Flow Matching : 常微分方程ODE、欧拉方法和Neural ODE
当 Normalizing flow 遇上语音生成:AI 说话变 “真人” 的秘密在这里!
深度剖析:Kimi – Audio 中 BigVGAN 的神奇作用
为什么说分布变换是 Normalizing flow 的「灵魂操作」?
MATCHA-TTS 来了!条件流匹配让文本转语音效率飙升
从知识增长的角度提升RAG上下文的质量
MiniMax-Speech,零样本语音合成新突破,32 种语言轻松拿捏!
手把手教你创建 evol-instruct 数据集!附完整流程~
社交类聊天的 Query 分析与应答策略
SFT 中指令选择和响应选择哪个更重要?
角色扮演大模型技术分享2-超拟人模型的困境
最新!SpeechLLM 综述:架构、能力、挑战与未来全揭秘
如何低成本生成高质量指令微调数据?
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
Kimi-Audio:开源音频基础模型全面解析
Kimi-Audio 的 TTS 效果如何?
Qwen 的训练数据是怎么做的?
GeForce RTX 3090, 4090, A10, A40, A100, A800, L20, L40 显卡性能对比
如何低成本生成高质量指令微调数据?
掌握RAG:投入生产前要评估的8个场景
掌握RAG:如何评估RAG的LLM
掌握RAG:如何在部署后观察您的RAG
掌握RAG:如何选择嵌入模型
基础模型中的新范式:为什么o1是不同的,以及它将如何改变LLM应用
Semantic token和连续特征在SLLM下的对比
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
RLHF及其变体:进展和实际工程见解
Freeze-Omni: 低延迟语音对话模型
Fully Sharded Data Parallelism (FSDP)
什么是置信度?置信度模型怎么做?
晦涩难懂的 Flow matching!图形化理解
中文指令微调数据,质量就是一切!
基于 LLM 的文本泛化
CosyVoice 2:基于大型语言模型的可扩展流式语音合成技术
Mini-Omni2: with Vision, Speech and Duplex Capabilities
FSQ的原理与VQ-VAE的区别和联系
大模型并行训练的一些知识——极简版
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
要用 AI 裁员 50% 的千亿独角兽,公开认错,重启招聘!
single codebook和dual codebook在LLM中向量量化上有什么区别?
一些文档去重算法
最佳的指令数据应当是什么样的?
Prefill-Decode分离
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
Simhash-文档去重算法简介
RLHF 入门,高手勿进!
最佳的指令数据应当是什么样的?
CosyVoice:一种基于监督式语义标记的可扩展多语言 Zero-Shot 语音合成器
Model Context Protocol (MCP)
MCP(模型上下文协议)是什么以及它是如何运作的
压力测试LLMs——大海捞针实现
本站文章均为手工撰写未经允许谢绝转载:夜雨聆风 » TensorRT-LLM 0.5.0 源码之六

猜你喜欢

  • 暂无文章