TensorRT-LLM 0.5.0 源码之四
builder.py
class _BuildingFlag:
    """Context manager that sets ``IS_BUILDING=1`` in ``os.environ``.

    Whatever the variable held before entering (or its absence) is restored
    on exit. This makes nested uses safe and preserves a value set outside
    the context.
    """

    _ENV_KEY = 'IS_BUILDING'

    def __init__(self):
        # Value of the variable before __enter__; None means "was not set".
        self._saved = None

    def __enter__(self):
        # Remember the previous state. Unconditionally deleting the variable
        # on exit would raise KeyError for the outer context when two flags
        # are nested, and would clobber a value set by the caller.
        self._saved = os.environ.get(self._ENV_KEY)
        os.environ[self._ENV_KEY] = '1'
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self._saved is None:
            os.environ.pop(self._ENV_KEY, None)
        else:
            os.environ[self._ENV_KEY] = self._saved
        # Implicitly returns None, so any in-flight exception propagates.
def _is_building(f):
    '''Decorator that runs ``f`` inside a ``_BuildingFlag`` context.

    Use it on functions that are called while building or refitting an
    engine; otherwise the plugin registration will fail.
    '''

    @wraps(f)
    def decorated(*args, **kwargs):
        building_flag = _BuildingFlag()
        with building_flag:
            outcome = f(*args, **kwargs)
        return outcome

    return decorated
class BuilderConfig:
    """Holds a ``trt.IBuilderConfig`` plus arbitrary user-supplied attributes.

    Instances are meant to come from ``Builder.create_builder_config()``,
    which calls the private ``_init`` helper. The constructor accepts and
    ignores keyword arguments.
    """

    def __init__(self, **kwargs):
        # **kwargs is accepted on purpose and discarded. Users should never
        # call this constructor directly; use Builder.create_builder_config().
        pass

    def _init(self, trt_builder_config, **kwargs):
        """Attach the TensorRT config and every keyword as an attribute.

        Returns ``self`` so the call can be chained after construction.
        """
        self._trt_builder_config = trt_builder_config
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
        return self

    @property
    def trt_builder_config(self) -> trt.IBuilderConfig:
        """The wrapped TensorRT builder config object."""
        return self._trt_builder_config
class Builder:
    """Wraps a ``trt.Builder`` to create networks and configs, and to build or refit engines."""

    _ALLOWED_PRECISIONS = ['float32', 'float16', 'bfloat16']

    def __init__(self):
        super().__init__()
        self._trt_builder = trt.Builder(logger.trt_logger)
        # Updated by create_builder_config(); read by create_network().
        self.strongly_typed = False

    @property
    def trt_builder(self) -> trt.Builder:
        return self._trt_builder

    def create_network(self) -> Network:
        '''Create an explicit-batch TensorRT network wrapped in ``Network``.

        The STRONGLY_TYPED creation flag is added only when the installed
        TensorRT version is at least 9.1.0 and ``self.strongly_typed`` is set.
        ``create_builder_config`` updates that attribute.

        Per the original author's note, a strongly typed network carries
        explicit data types on its tensors. This lets type mismatches be
        reported earlier in the build.
        '''
        flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        if version.parse(trt_version()) >= version.parse(
                "9.1.0") and self.strongly_typed:
            # STRONGLY_TYPED is only referenced after the version check passes.
            flags |= 1 << int(
                trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
        return Network()._init(self.trt_builder.create_network(flags))

    def create_builder_config(self,
                              precision: str,
                              timing_cache: Union[str, Path,
                                                  trt.ITimingCache] = None,
                              tensor_parallel: int = 1,
                              use_refit: bool = False,
                              int8: bool = False,
                              fp8: bool = False,
                              strongly_typed: bool = False,
                              opt_level: Optional[int] = None,
                              **kwargs) -> BuilderConfig:
        ''' @brief Create a builder config with given precisions and timing cache
            @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS
                (not checked when strongly_typed is True)
            @param timing_cache: a timing cache object or a path to an existing timing cache file;
                any other non-None value is reported and replaced by a freshly created cache
            @param tensor_parallel: number of GPUs used for tensor parallel
            @param use_refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others
            @param int8: whether to build with int8 enabled or not. Can't be used together with use_refit
            @param fp8: whether to build with fp8 enabled or not
            @param strongly_typed: build a strongly typed network. The value is stored on
                self (read later by create_network()), and when True the
                FP16/BF16/INT8/FP8 builder flags below are not set
            @param opt_level: optional TensorRT builder optimization level
            @param kwargs: any other arguments users would like to attach to the config object as attributes
            @return: A BuilderConfig object. An unsupported precision or the use_refit + int8
                combination is only reported via logger.error; a config is still returned.
        '''
        self.strongly_typed = strongly_typed

        if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
            logger.error(
                f"precision should be one of {self._ALLOWED_PRECISIONS}")
        if use_refit and int8:
            # TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes
            # These folded weights can not be refitted
            logger.error(f"can't use refit and int8 mode at the same time")

        config = self.trt_builder.create_builder_config()
        if not strongly_typed:
            if precision == 'float16':
                config.set_flag(trt.BuilderFlag.FP16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            elif precision == 'bfloat16':
                config.set_flag(trt.BuilderFlag.BF16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            if int8:
                config.set_flag(trt.BuilderFlag.INT8)
            if fp8:
                config.set_flag(trt.BuilderFlag.FP8)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
        config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806,
                                   True)
        if use_refit:
            config.set_flag(trt.BuilderFlag.REFIT)
        if opt_level is not None:
            config.builder_optimization_level = opt_level

        # set timing cache
        cache = None
        if timing_cache is not None:
            if isinstance(timing_cache, trt.ITimingCache):
                # use given cache
                cache = timing_cache
            elif isinstance(timing_cache,
                            (str, Path)) and os.path.exists(timing_cache):
                # read cache from file
                with open(timing_cache, "rb") as f:
                    cache = config.create_timing_cache(f.read())
            else:
                logger.warning(
                    "Invalid timing cache, using freshly created one")
        if cache is None:
            cache = config.create_timing_cache(b"")
        # When the user gives no usable cache, one is always created above,
        # so the cache should never be None here.
        assert cache is not None and isinstance(cache, trt.ITimingCache)
        config.set_timing_cache(cache, ignore_mismatch=False)

        return BuilderConfig()._init(config,
                                     precision=precision,
                                     tensor_parallel=tensor_parallel,
                                     use_refit=use_refit,
                                     int8=int8,
                                     fp8=fp8,
                                     **kwargs)

    def _add_optimization_profile(self, network: Network,
                                  builder_config: BuilderConfig):
        '''Add one TensorRT optimization profile per input shape profile.

        Optimization profiles let TensorRT handle dynamic input sizes. Each
        profile gives every network input a (min, opt, max) shape range.

        The number of profiles is taken from the first network input. Every
        input is expected to provide that many shape profiles, otherwise
        indexing raises IndexError.

        Raises:
            AssertionError: if _validate_named_dimensions reports inconsistent
                ranges for a shared named dimension.
        '''
        assert isinstance(builder_config, BuilderConfig)
        assert isinstance(network, Network)
        input_tensors = network._inputs
        # How many profiles to create: based on the first input tensor.
        num_profiles = len(list(input_tensors.values())[0].profiles)
        for i in range(num_profiles):
            logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
            profile = self.trt_builder.create_optimization_profile()
            for input_name, input_tensor in input_tensors.items():
                shape_profile = input_tensor.profiles[i]
                profile.set_shape(input_name, shape_profile.min,
                                  shape_profile.opt, shape_profile.max)
                logger.debug(
                    f'{input_name}, min: {shape_profile.min}, opt: {shape_profile.opt}, max: {shape_profile.max}, dimension names: {shape_profile.dimension_names}'
                )
            builder_config.trt_builder_config.add_optimization_profile(profile)
        # Call the validator outside an ``assert`` statement. A bare assert is
        # removed under ``python -O``, and with it the validation call and its
        # error log. The raised exception type (AssertionError) is unchanged.
        if not self._validate_named_dimensions(network, builder_config):
            raise AssertionError(
                "Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log"
            )

    def _validate_named_dimensions(self, network: Network,
                                   builder_config) -> bool:
        '''
        For each profile, validate that the named dimensions of different input tensors in this profile all have same range.
        TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT-LLM are correct and
        makes the error msg more user friendly.

        Inputs without any shape profile are skipped. Every mismatch is logged,
        and the function returns False if at least one was found.
        '''
        valid = True
        num_profiles = builder_config.trt_builder_config.num_optimization_profiles
        for profile_idx in range(num_profiles):
            # dimension name -> [(input name, (min, opt, max)), ...]
            dimension_to_range = {}
            for input_name, input_tensor in network._inputs.items():
                # NOTE(review): the original author asks whether a tensor
                # without dim_range is legal; such tensors are skipped here.
                if len(input_tensor.profiles) != 0:
                    profile = input_tensor.profiles[profile_idx]
                    for dim_idx, dim_name in enumerate(profile.dimension_names):
                        dim_range = (profile.min[dim_idx],
                                     profile.opt[dim_idx],
                                     profile.max[dim_idx])
                        dimension_to_range.setdefault(dim_name, []).append(
                            (input_name, dim_range))
            # Every occurrence of a named dimension must share one range.
            for dim, ranges in dimension_to_range.items():
                unique_ranges = {r[1] for r in ranges}
                logger.debug(
                    f"Validating dimension:{dim}, ranges for this dim are:{unique_ranges}"
                )
                if len(unique_ranges) != 1:
                    logger.error(
                        f"Found illegal dimension setting for profile {profile_idx}, dimension name is: {dim}"
                    )
                    logger.error(
                        f"Offensive tensors which have this dimension are:\n" +
                        "\n".join([f"{r[1]} {dim} {r[0]}" for r in ranges]))
                    valid = False
        return valid

    @_is_building
    def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
        '''
        @brief: Refit one TensorRT engine using weights from the network,
            user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.
            This replaces the engine's weights without rebuilding it.
        @param engine_buffer: A serialized TensorRT engine.
        @param network: Network object.
        @return: A serialized TRT engine if refit successfully, None otherwise
        '''
        assert isinstance(network, Network)
        logger.info(f'Refit TRT engine')
        runtime = trt.Runtime(logger.trt_logger)
        engine = runtime.deserialize_cuda_engine(engine_buffer)

        tik = time.time()
        refitter = trt.Refitter(engine, logger.trt_logger)
        if network.named_parameters is None:
            logger.error(
                f'Please set named parameters before building multiple engines.'
            )
            return None
        # Hand every named weight of the network to the refitter.
        for name, param in network.named_parameters:
            weights = param._get_weights()
            if weights is None or not refitter.set_named_weights(
                    name, weights):
                logger.error(f'Failed to refit weight: {name}')
                return None
        # Apply the new weights to the engine.
        if not refitter.refit_cuda_engine():
            logger.error(f'Failed to refit engine.')
            return None
        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of refitting {engine.name}: {t}')
        return engine.serialize()

    @_is_building
    def build_engine(self, network: Network,
                     builder_config: BuilderConfig) -> trt.IHostMemory:
        '''
        @brief: Build one TensorRT engine from the network.
        @param network: Network object.
        @param builder_config: BuilderConfig object.
        @return: A serialized TRT engine, or None if TensorRT failed to build it.
        @raise RuntimeError: if a named weight cannot be attached to the network.
        '''
        assert isinstance(network, Network)
        builder_config.plugin_config = network.plugin_config
        self._add_optimization_profile(network, builder_config)
        logger.info(f'Build TensorRT engine {network.trt_network.name}')
        tik = time.time()

        # Rename weights so they can be looked up by name (e.g. by a later refit).
        if network.named_parameters is not None:
            for name, param in network.named_parameters:
                weights = param._get_weights()
                if weights is None or not network.trt_network.set_weights_name(
                        weights, name):
                    raise RuntimeError(f'Failed to set weight: {name}')

        engine = self.trt_builder.build_serialized_network(
            network.trt_network, builder_config.trt_builder_config)
        if engine is None:
            logger.error('Engine building failed, please check the error log.')
            return None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of building {network.trt_network.name}: {t}')
        return engine

    @staticmethod
    def save_timing_cache(builder_config: BuilderConfig, out_path: str) -> bool:
        '''Serialize timing cache of given builder config to file specified by out_path
            return True if the cache is successfully serialized, False otherwise
        '''
        cache = builder_config.trt_builder_config.get_timing_cache()
        if cache is None:
            logger.warning(
                'No timing cache found in the given builder config, skip saving.'
            )
            return False
        with cache.serialize() as buffer:
            with open(out_path, "wb") as f:
                f.write(buffer)
                f.flush()
                # Force the bytes to disk before reporting success.
                os.fsync(f)
        logger.info(f'Timing cache serialized to {out_path}')
        return True

    @staticmethod
    def save_config(builder_config: BuilderConfig, config_path: str):
        '''Write the builder config attributes and the plugin config to a JSON file.

        The wrapped TensorRT config is not written. The plugin config is stored
        under its own top-level key.
        '''
        skipped = ('_trt_builder_config', 'plugin_config')
        config = {
            'builder_config': {
                key: value
                for key, value in vars(builder_config).items()
                if key not in skipped
            }
        }
        config['plugin_config'] = to_dict(builder_config.plugin_config)
        to_json_file(config, config_path)
        logger.info(f'Config saved to {config_path}.')
参考文献

- https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/builder.py

夜雨聆风
