乐于分享
好东西不私藏

TensorRT-LLM 0.5.0 源码之四

TensorRT-LLM 0.5.0 源码之四

builder.py

class _BuildingFlag:

    def
 __enter__(self):
        os.environ['IS_BUILDING'] = '1'

    def
 __exit__(self, type, value, tb):
        del
 os.environ['IS_BUILDING']


def
 _is_building(f):
    '''Use this to decorate functions which are called during engine building/refiting process,
    otherwise, the plugin registration will fail.
    '''


    @wraps(f)

    def
 decorated(*args, **kwargs):
        with
 _BuildingFlag():
            return
 f(*args, **kwargs)

    return
 decorated

class
 BuilderConfig(object):

    def
 __init__(self, **kwargs):
        # intentionally use **kwargs, user should never call this ctor directly,

        # use Builder.create_builder_config() instead

        pass


    def
 _init(self, trt_builder_config, **kwargs):
        self
._trt_builder_config = trt_builder_config
        for
 key, value in kwargs.items():
            setattr
(self, key, value)
        return
 self

    @property

    def
 trt_builder_config(self) -> trt.IBuilderConfig:
        return
 self._trt_builder_config
class Builder():

    _ALLOWED_PRECISIONS = ['float32', 'float16', 'bfloat16']

    def
 __init__(self):
        super
().__init__()
        self
._trt_builder = trt.Builder(logger.trt_logger)
        self
.strongly_typed = False

    @property

    def
 trt_builder(self) -> trt.Builder:
        return
 self._trt_builder
    def create_network(self) -> Network:
        if
 version.parse(trt_version()) >= version.parse(
                "9.1.0"
) and self.strongly_typed:
            # STRONGLY_TYPED 它要求创建一个强类型网络。在强类型网络中,每个张量和层的输入/输出都必须明确指定其数据类型(如 float32, int8),TensorRT构建器会进行更严格的类型检查,这有助于在模型构建的早期发现类型不匹配的错误,提升模型的稳健性

            return
 Network()._init(
                self
.trt_builder.create_network(
                    (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
                    | (1 << int(
                        trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))))
        else
:
            return
 Network()._init(
                self
.trt_builder.create_network(
                    1
 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)))
    def create_builder_config(self,
                              precision: str,
                              timing_cache: Union[str, Path,
                                                  trt.ITimingCache] = None,
                              tensor_parallel: int = 1,
                              use_refit: bool = False,
                              int8: bool = False,
                              fp8: bool = False,
                              strongly_typed: bool = False,
                              opt_level: Optional[int] = None,
                              **kwargs
) -> BuilderConfig:
        ''' @brief Create a builder config with given precisions and timing cache
            @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS
            @param timing_cache: a timing cache object or a path to a timing cache file
            @param tensor_parallel: number of GPUs used for tensor parallel
            @param kwargs: any other arguments users would like to attach to the config object as attributes
            @param refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others
            @param int8: whether to build with int8 enabled or not. Can't be used together with refit option
            @return: A BuilderConfig object, return None if failed
        '''

        self
.strongly_typed = strongly_typed

        if
 not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
            logger.error(
                f"precision should be one of {self._ALLOWED_PRECISIONS}"
)

        if
 use_refit and int8:
            # TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes

            # These folded weights can not be refitted

            logger.error(f"can't use refit and int8 mode at the same time")

        config = self.trt_builder.create_builder_config()
        if
 not strongly_typed:
            if
 precision == 'float16':
                config.set_flag(trt.BuilderFlag.FP16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            elif
 precision == 'bfloat16':
                config.set_flag(trt.BuilderFlag.BF16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            if
 int8:
                config.set_flag(trt.BuilderFlag.INT8)

            if
 fp8:
                config.set_flag(trt.BuilderFlag.FP8)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

        config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806,
                                   True
)

        if
 use_refit:
            config.set_flag(trt.BuilderFlag.REFIT)

        if
 opt_level is not None:
            config.builder_optimization_level = opt_level

        # set timing cache

        cache = None
        if
 timing_cache is not None:
            # use given cache

            if
 isinstance(timing_cache, trt.ITimingCache):
                cache = timing_cache
            # read cache from file

            elif
 isinstance(timing_cache,
                            (str, Path)) and os.path.exists(timing_cache):
                with
 open(timing_cache, "rb") as f:
                    cache = config.create_timing_cache(f.read())
            else
:
                logger.warning(
                    "Invalid timing cache, using freshly created one"
)
        if
 cache is None:
            cache = config.create_timing_cache(b"")
        # When user does not given any existing cache, internally always created one

        # so the cache should never None here

        assert
 cache is not None and isinstance(cache, trt.ITimingCache)
        config.set_timing_cache(cache, ignore_mismatch=False)

        return
 BuilderConfig()._init(config,
                                     precision=precision,
                                     tensor_parallel=tensor_parallel,
                                     use_refit=use_refit,
                                     int8=int8,
                                     fp8=fp8,
                                     **kwargs)
    # 优化配置文件让TensorRT能够处理动态输入尺寸。在构建引擎时,TensorRT会根据你提供的形状范围提前优化内核选择、内存分配等,这样在推理时就可以在这个范围内灵活处理不同尺寸的输入,而无需重新构建引擎

    def
 _add_optimization_profile(self, network: Network,
                                  builder_config: BuilderConfig
):
        assert
 isinstance(builder_config, BuilderConfig)
        assert
 isinstance(network, Network)
        input_tensors = network._inputs

        # 确定需要创建多少个优化配置文件(基于第一个输入张量的配置数量)

        num_profiles = len(list(input_tensors.items())[0][1].profiles)
             # 为每个配置文件编号循环

        for
 i in range(num_profiles):
            logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
            # 1. 创建空的优化配置文件

            profile = self.trt_builder.create_optimization_profile()
                    # 2. 为每个输入张量设置形状范围

            for
 input_name in input_tensors.keys():
                shape_profile = input_tensors[input_name].profiles[i]
                # 设置该输入在三个场景下的形状尺寸

                profile.set_shape(input_name, shape_profile.min,
                                  shape_profile.opt, shape_profile.max)
                logger.debug(
                    f'{input_name}, min: {shape_profile.min}, opt: {shape_profile.opt}, max: {shape_profile.max}, dimension names: {shape_profile.dimension_names}'

                )
            # 3. 将配置好的profile添加到构建器配置中

            builder_config.trt_builder_config.add_optimization_profile(profile)
        # 4. 验证维度配置的正确性

        assert
 self._validate_named_dimensions(
            network, builder_config
        ), "Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log"
    # 在TensorRT中,优化配置文件通过set_shape方法为每个输入张量定义最小、最优、最大形状范围。如果多个输入张量共享同一个命名维度(如batch_size),那么该维度在所有张量中的范围必须一致,否则会导致运行时错误
    def
 _validate_named_dimensions(self, network: Network,
                                   builder_config
) -> bool:
        '''
            For each profile, validate that the named dimensions of different input tensors in this profile all have same range.
            TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT-LLM are correct and
            makes the error msg more user friendly.
        '''

        valid = True
        for
 profile_idx in range(
                builder_config.trt_builder_config.num_optimization_profiles):
            # 建立维度名称到范围值的映射

            dimension_to_range = {}
            for
 input_name, input_tensor in network._inputs.items():
                # it's legal that a Tensor does not have dim_range?

                if
 len(input_tensor.profiles) != 0:
                    profile = input_tensor.profiles[profile_idx]
                    for
 dim_idx, dim_name in enumerate(profile.dimension_names):
                        if
 dim_name not in dimension_to_range:
                            dimension_to_range[dim_name] = []
                        # 记录每个维度的范围信息

                        min
, opt, max = profile.min[dim_idx], profile.opt[
                            dim_idx], profile.max[dim_idx]
                        dimension_to_range[dim_name].append(
                            (input_name, (min, opt, max)))
            # 验证同一维度的范围是否一致

            for
 dim, ranges in dimension_to_range.items():
                unique_ranges = set([r[1] for r in ranges])
                logger.debug(
                    f"Validating dimension:{dim}, ranges for this dim are:{unique_ranges}"

                )
                if
 len(unique_ranges) != 1: # 如果存在不一致的范围
                    logger.error(
                        f"Found illegal dimension setting for profile {profile_idx}, dimension name is: {dim}"

                    )
                    logger.error(
                        f"Offensive tensors which have this dimension are:\n"
 +
                        "\n"
.join([f"{r[1]} {dim} {r[0]}" for r in ranges]))
                    valid = False
        return
 valid
    # 这是一个用于重构(Refit)TensorRT引擎的Python函数,其核心功能是在不重新构建整个引擎的情况下,更新引擎中的权重参数。
    # 主要目标:利用网络(network)中的新权重数据,快速重构一个已序列化的TensorRT引擎(engine_buffer)。这避免了重新构建引擎的开销,尤其适用于权重频繁更新的场景。

    # 前提条件:原引擎必须已启用REFIT标志(构建时设置),且network的结构需与引擎一致


    @_is_building

    def
 refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
        '''
            @brief: Refit one TensorRT engine using weights from the network,
                user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.
            @param engine_buffer: A serialized TensorRT engine.
            @param network: Network object.
            @return: A serialized TRT engine if refit successfully, None otherwise
        '''

        assert
 isinstance(network, Network)
        logger.info(f'Refit TRT engine')
        runtime = trt.Runtime(logger.trt_logger)
        engine = runtime.deserialize_cuda_engine(engine_buffer)

        tik = time.time()

        # Refit engine

        refitter = trt.Refitter(engine, logger.trt_logger)
        # 设置新权重:

        if
 network.named_parameters is not None:
            for
 name, param in network.named_parameters:
                if
 param._get_weights(
                ) is None or not refitter.set_named_weights(
                        name, param._get_weights()):
                    logger.error(f'Failed to refit weight: {name}')
                    return
 None
        else
:
            logger.error(
                f'Please set named parameters before building multiple engines.'

            )
            return
 None
        # 将新权重应用到引擎中

        if
 not refitter.refit_cuda_engine():
            logger.error(f'Failed to refit engine.')
            return
 None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of refitting {engine.name}: {t}')
        serialized_engine = engine.serialize()
        return
 serialized_engine
    @_is_building
    def
 build_engine(self, network: Network,
                     builder_config: BuilderConfig
) -> trt.IHostMemory:
        '''
            @brief: Build one TensorRT engine from the network.
            @param network: Network object.
            @param builder_config: BuilderConfig object.
            @return: A serialized TRT engine.
        '''

        assert
 isinstance(network, Network)
        builder_config.plugin_config = network.plugin_config
        self
._add_optimization_profile(network, builder_config)
        engine = None
        logger.info(f'Build TensorRT engine {network.trt_network.name}')
        tik = time.time()

        # Rename weights

        if
 network.named_parameters is not None:
            for
 name, param in network.named_parameters:
                if
 param._get_weights(
                ) is None or not network.trt_network.set_weights_name(
                        param._get_weights(), name):
                    raise
 RuntimeError(f'Failed to set weight: {name}')

        # Build engine

        engine = self.trt_builder.build_serialized_network(
            network.trt_network, builder_config.trt_builder_config)
        if
 engine is None:
            logger.error('Engine building failed, please check the error log.')
            return
 None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of building {network.trt_network.name}: {t}')

        return
 engine
    @staticmethod
    def
 save_timing_cache(builder_config: BuilderConfig, out_path: str) -> bool:
        '''Serialize timing cache of given builder config to file specified by out_path
            return True if the cache is successfully serialized, False otherwise
        '''

        cache = builder_config.trt_builder_config.get_timing_cache()
        if
 cache is None:
            logger.warning(
                'No timing cache found in the given builder config, skip saving.'

            )
            return
 False
        with
 cache.serialize() as buffer:
            with
 open(out_path, "wb") as f:
                f.write(buffer)
                f.flush()
                os.fsync(f)
        logger.info(f'Timing cache serialized to {out_path}')
        return
 True
    @staticmethod
    def
 save_config(builder_config: BuilderConfig, config_path: str):
        config = {'builder_config': {}}
        for
 k in builder_config.__dict__.keys():
            if
 k != '_trt_builder_config' and k != 'plugin_config':
                config['builder_config'][k] = builder_config.__getattribute__(k)
        config['plugin_config'] = to_dict(builder_config.plugin_config)
        to_json_file(config, config_path)
        logger.info(f'Config saved to {config_path}.')

参考文献

  • • https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/builder.py
点个「赞」+「在看」❤️
让我们知道这份文字有温暖到你,也是我们持续创作的最大动力!
推荐
Share Memory 的 Bank Conflict
告别高成本!TensorRT-LLM实战:如何将LLM推理速度提升数倍
使用LoRA对LLM进行微调的实用技巧
强化学习小白必看:PTX Loss 到底是个啥?
GPT-5 Prompt Migration and Improvement Using the New Optimizer
Task 异步流 coroutine 实现
C++ corotine 介绍
搭建 VSCode 离线开发环境
nlohmann/json 库简介
Intro to C++ Coroutines: Concept
Hugging Face BPE Tokenizer 的资源文件
移动语义 std::move 和完美转发 std::forward
ACEBench: Who Wins the Match Point in Tool Usage?
什么是 GN
RULER: Relative Universal LLM-Elicited Rewards
SFT和RFT的区别
CosyVoice 3: 面向真实场景的大规模零样本语音生成模型
CosyVoice 3: Towards In-the-wild Speech Generation
语音合成(TTS)中文自然度:问题、成因、解决方案
上下文工程如何实现
上下文工程(Context Engineering)
新手必看!LangGraph 101:手把手教你搭一个深度研究 Agent
LangGraph 简介
SFT 泛化新解读:强化学习 + 奖励修正,一文读懂
程序员狂喜!Self-Instruct 框架全解析:无限生成高质量指令集,从此告别标注噩梦!
Evol-Instruct 竟能精准生成领域专属数据?实操技巧速看!
指令微调数据-少即是多
LLM generate 参数怎么用?
语音合成(TTS)跳跃与重复问题的解析:成因、机制及解决方案
大模型训练新思路:GEPA 靠 “反思” 赢过 RL,看完秒懂
F5-TTS:用 Flow Matching 玩转语音,流畅度和真实感都 “拉满” 了
E2 TTS:令人尴尬地简单、完全非自回归、零样本的语音合成技术
Voicebox:大规模文本引导的多语言通用语音生成技术
为什么都在聊 Kimi K2?Open Agentic Intelligence 藏着哪些新惊喜
Step-Audio-AQAA 端到端音频模型
DPO、PPO、GRPO的原理,区别与联系
OPENCSG 中文语料库:一系列高质量的中文数据集,用于语言模型训练
什么是 Classifier-Free Guidance?
Conditional Flow Matching : 连续标准流 Continuous Normalizing Flow
CFM 与 OT-CFM:条件流匹配与最优传输的碰撞
DPO损失实现
Conditional Flow Matching : 常微分方程ODE、欧拉方法和Neural ODE
当 Normalizing flow 遇上语音生成:AI 说话变 “真人” 的秘密在这里!
深度剖析:Kimi – Audio 中 BigVGAN 的神奇作用
为什么说分布变换是 Normalizing flow 的「灵魂操作」?
MATCHA-TTS 来了!条件流匹配让文本转语音效率飙升
从知识增长的角度提升RAG上下文的质量
MiniMax-Speech,零样本语音合成新突破,32 种语言轻松拿捏!
手把手教你创建 evol-instruct 数据集!附完整流程~
社交类聊天的 Query 分析与应答策略
SFT 中指令选择和响应选择哪个更重要?
角色扮演大模型技术分享2-超拟人模型的困境
最新!SpeechLLM 综述:架构、能力、挑战与未来全揭秘
如何低成本生成高质量指令微调数据?
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
Kimi-Audio:开源音频基础模型全面解析
Kimi-Audio 的 TTS 效果如何?
Qwen 的训练数据是怎么做的?
GeForce RTX 3090, 4090, A10, A40, A100, A800, L20, L40 显卡性能对比
如何低成本生成高质量指令微调数据?
掌握RAG:投入生产前要评估的8个场景
掌握RAG:如何评估RAG的LLM
掌握RAG:如何在部署后观察您的RAG
掌握RAG:如何选择嵌入模型
基础模型中的新范式:为什么o1是不同的,以及它将如何改变LLM应用
Semantic token和连续特征在SLLM下的对比
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
RLHF及其变体:进展和实际工程见解
Freeze-Omni: 低延迟语音对话模型
Fully Sharded Data Parallelism (FSDP)
什么是置信度?置信度模型怎么做?
晦涩难懂的 Flow matching!图形化理解
中文指令微调数据,质量就是一切!
基于 LLM 的文本泛化
CosyVoice 2:基于大型语言模型的可扩展流式语音合成技术
Mini-Omni2: with Vision, Speech and Duplex Capabilities
FSQ的原理与VQ-VAE的区别和联系
大模型并行训练的一些知识——极简版
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
要用 AI 裁员 50% 的千亿独角兽,公开认错,重启招聘!
single codebook和dual codebook在LLM中向量量化上有什么区别?
一些文档去重算法
最佳的指令数据应当是什么样的?
Prefill-Decode分离
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
Simhash-文档去重算法简介
RLHF 入门,高手勿进!
最佳的指令数据应当是什么样的?
CosyVoice:一种基于监督式语义标记的可扩展多语言 Zero-Shot 语音合成器
Model Context Protocol (MCP)
MCP(模型上下文协议)是什么以及它是如何运作的
压力测试LLMs——大海捞针实现
本站文章均为手工撰写未经允许谢绝转载:夜雨聆风 » TensorRT-LLM 0.5.0 源码之四

评论 抢沙发

5 + 4 =
  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
×
订阅图标按钮