TensorRT-LLM 0.5.0 Source Code, Part 3
_common.py
```python
# imports used by the excerpts from _common.py below
import contextlib
import platform
from pathlib import Path

import torch

from ._utils import str_dtype_to_trt
from .logger import logger
from .plugin import _load_plugin_lib

net = None  # the current default Network()
_inited = False
```
```python
def _init(log_level=None):
    global _inited
    if _inited:
        return
    _inited = True

    # Move to __init__
    if log_level is not None:
        logger.set_level(log_level)

    # load plugin lib
    _load_plugin_lib()

    # load FT decoder layer
    project_dir = str(Path(__file__).parent.absolute())
    if platform.system() == "Windows":
        ft_decoder_lib = project_dir + '/libs/th_common.dll'
    else:
        ft_decoder_lib = project_dir + '/libs/libth_common.so'
    if ft_decoder_lib == '':
        raise ImportError('FT decoder layer is unavailable')
    torch.classes.load_library(ft_decoder_lib)

    global net
    logger.info('TensorRT-LLM inited.')
```
```python
def default_net():
    assert net, "Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
    return net


def default_trtnet():
    return default_net().trt_network


def set_network(network):
    global net
    net = network
```
```python
def switch_net_dtype(cur_dtype):
    prev_dtype = default_net().dtype
    default_net().dtype = cur_dtype
    return prev_dtype


@contextlib.contextmanager
def precision(dtype):
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)
    prev_dtype = switch_net_dtype(dtype)
    yield
    switch_net_dtype(prev_dtype)
```
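A usage sketch of `precision` (the `Builder`/`create_network` calls come from the surrounding library, not this file, and the flow is an assumption based on the assert message in `default_net`):

```python
import tensorrt_llm
from tensorrt_llm._common import precision, set_network

builder = tensorrt_llm.Builder()
network = builder.create_network()
set_network(network)           # make it the module-level default net

with precision('float16'):     # default_net().dtype is trt.float16 here
    pass                       # ... build fp16 layers ...
# on exit, switch_net_dtype(prev_dtype) restores the previous dtype
```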
__init__.py
Because the package `__init__.py` calls `_init` at import time, a plain `import tensorrt_llm` already loads the plugin library and registers the FT decoder ops:

```python
_init(log_level="error")
```
_utils.py
```python
# imports used by the excerpts from _utils.py below
import copy
import ctypes
import json
import math
import time
from functools import partial

import numpy as np
import tensorrt as trt
import torch

from .logger import logger

fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)
```
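These partials simply pin the dtype of `np.array`:

```python
print(fp16_array([1.0, 2.0]).dtype)  # float16
print(int32_array([[0, 1]]).shape)   # (1, 2)
```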
```python
# numpy doesn't know bfloat16, define abstract binary type instead
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})


def torch_to_numpy(x):
    if x.dtype != torch.bfloat16:
        return x.numpy()
    return x.view(torch.int16).numpy().view(np_bfloat16)
```
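Why the double `view` works, in a quick sketch: numpy has no native bfloat16, so the tensor's 16-bit payload is reinterpreted as `int16` and then relabeled with the opaque 2-byte `V2` dtype, preserving the raw bits with no conversion (assumes the definitions above are in scope):

```python
x = torch.tensor([1.0, 2.5], dtype=torch.bfloat16)
arr = torch_to_numpy(x)
print(arr.itemsize, arr.dtype.metadata)  # 2 {'dtype': 'bfloat16'}

# going back: reinterpret the same bits as int16, then as torch.bfloat16
y = torch.from_numpy(arr.view(np.int16)).view(torch.bfloat16)
assert torch.equal(x, y)
```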
```python
def trt_version():
    return trt.__version__


def torch_version():
    return torch.__version__
```
```python
_str_to_np_dict = dict(
    float16=np.float16,
    float32=np.float32,
    int32=np.int32,
    bfloat16=np_bfloat16,
)


def str_dtype_to_np(dtype):
    ret = _str_to_np_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_torch_dtype_dict = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
    float32=torch.float32,
    int32=torch.int32,
    int8=torch.int8,
)


def str_dtype_to_torch(dtype):
    ret = _str_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_trt_dtype_dict = dict(float16=trt.float16,
                              float32=trt.float32,
                              int64=trt.int64,
                              int32=trt.int32,
                              int8=trt.int8,
                              bool=trt.bool,
                              bfloat16=trt.bfloat16,
                              fp8=trt.fp8)


def str_dtype_to_trt(dtype):
    ret = _str_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_np_to_trt_dtype_dict = {
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.float16: trt.float16,
    np.float32: trt.float32,

    # hash of np.dtype('int32') != np.int32
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
}


def np_dtype_to_trt(dtype):
    if trt_version() >= '7.0' and dtype == np.bool_:
        return trt.bool
    if trt_version() >= '9.0' and dtype == np_bfloat16:
        return trt.bfloat16
    ret = _np_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
}


def trt_dtype_to_np(dtype):
    if trt_version() >= '9.0' and dtype == trt.bfloat16:
        return np_bfloat16
    ret = _trt_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_np_dtype_dict = {
    torch.float16: np.float16,
    torch.float32: np.float32,
}


def torch_dtype_to_np(dtype):
    ret = _torch_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
}


def trt_dtype_to_torch(dtype):
    if trt_version() >= '9.0' and dtype == trt.bfloat16:
        return torch.bfloat16
    ret = _trt_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret
```
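The string names act as the common currency between numpy, torch, and TensorRT; a quick sketch of the lookups (assuming the functions above are in scope):

```python
t = str_dtype_to_trt('float16')     # trt.float16
print(str_dtype_to_np('float16'))   # <class 'numpy.float16'>
print(str_dtype_to_torch('int32'))  # torch.int32
print(trt_dtype_to_torch(t))        # torch.float16

# anything missing from a table fails loudly:
# str_dtype_to_np('fp8')  ->  AssertionError: Unsupported dtype: fp8
```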
```python
def dim_to_trt_axes(dim):
    """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
    if not isinstance(dim, tuple):
        dim = (dim, )

    # create axes bitmask for reduce layer
    axes = 0
    for d in dim:
        axes |= 1 << d

    return axes
```
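TensorRT reduce layers take an axes bitmask rather than a dim list: each dim `d` sets bit `d`. A worked example:

```python
assert dim_to_trt_axes(1) == 0b10        # dim 1 -> bit 1
assert dim_to_trt_axes((0, 2)) == 0b101  # dims 0 and 2 -> 5
```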
```python
def dim_resolve_negative(dim, ndim):
    if not isinstance(dim, tuple):
        dim = (dim, )
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)
```
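`dim_resolve_negative` normalizes Python-style negative dims before they reach `dim_to_trt_axes`, where `1 << d` with a negative `d` would raise:

```python
assert dim_resolve_negative(-1, 4) == (3,)        # last axis of a 4-D tensor
assert dim_resolve_negative((0, -2), 4) == (0, 2)
```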
```python
def serialize_engine(engine, path):
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    if isinstance(engine, trt.ICudaEngine):
        engine = engine.serialize()
    with open(path, 'wb') as f:
        f.write(bytearray(engine))
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')


def deserialize_engine(path):
    runtime = trt.Runtime(logger.trt_logger)
    with open(path, 'rb') as f:
        logger.info(f'Loading engine from {path}...')
        tik = time.time()
        engine = runtime.deserialize_cuda_engine(f.read())
        assert engine is not None
        tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine loaded. Total time: {t}')
    return engine
```
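A round-trip sketch (the path and the `engine` variable are placeholders; note that `serialize_engine` accepts either a `trt.ICudaEngine`, which it serializes first, or an already-serialized buffer):

```python
# engine = builder.build_engine(network, builder_config)  # hypothetical
serialize_engine(engine, '/tmp/llama_fp16.engine')
engine = deserialize_engine('/tmp/llama_fp16.engine')  # trt.ICudaEngine
```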
```python
def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD


def mpi_rank():
    return mpi_comm().Get_rank()


def mpi_world_size():
    return mpi_comm().Get_size()
```
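The `mpi4py` import is deferred inside `mpi_comm`, so single-process use never requires MPI to be installed. Under a multi-process launch, a sketch:

```python
# launched e.g. as: mpirun -n 2 python script.py
print(f'rank {mpi_rank()} of {mpi_world_size()}')
# prints "rank 0 of 2" and "rank 1 of 2", one per process
```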
```python
# round up: pad vocab_size to the nearest multiple of tp_size
def pad_vocab_size(vocab_size, tp_size):
    return int(math.ceil(vocab_size / tp_size) * tp_size)
```
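With tensor parallelism the embedding table is split along the vocab dimension, so the size must be rounded up to a multiple of `tp_size`:

```python
assert pad_vocab_size(32000, 2) == 32000  # already divisible
assert pad_vocab_size(32003, 2) == 32004  # ceil(32003/2)*2
assert pad_vocab_size(32000, 3) == 32001  # ceil(32000/3)*3 = 10667*3
```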
```python
def to_dict(obj):
    return copy.deepcopy(obj.__dict__)


def to_json_string(obj):
    if not isinstance(obj, dict):
        obj = to_dict(obj)
    return json.dumps(obj, indent=2, sort_keys=True) + "\n"


def to_json_file(obj, json_file_path):
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(to_json_string(obj))
```
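A sketch with a throwaway object (the `Config` class here is hypothetical, not from the library):

```python
class Config:  # hypothetical stand-in for a config object
    def __init__(self):
        self.name = 'llama'
        self.tp_size = 2

print(to_json_string(Config()), end='')
# {
#   "name": "llama",
#   "tp_size": 2
# }
to_json_file(Config(), '/tmp/config.json')
```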
```python
_field_dtype_to_np_dtype_dict = {
    trt.PluginFieldType.FLOAT16: np.float16,
    trt.PluginFieldType.FLOAT32: np.float32,
    trt.PluginFieldType.FLOAT64: np.float64,
    trt.PluginFieldType.INT8: np.int8,
    trt.PluginFieldType.INT16: np.int16,
    trt.PluginFieldType.INT32: np.int32,
}


def field_dtype_to_np_dtype(dtype):
    ret = _field_dtype_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def convert_capsule_to_void_p(capsule):
    # Extract the raw C pointer from a PyCapsule object. A PyCapsule is
    # Python's container for wrapping a C pointer, commonly seen in
    # C extension modules.
    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
        ctypes.py_object, ctypes.c_char_p
    ]
    # PyCapsule_GetPointer is a CPython C API function; passing None as the
    # second argument skips checking the capsule's name.
    return ctypes.pythonapi.PyCapsule_GetPointer(capsule, None)


def get_nparray_from_void_p(void_pointer, elem_size, field_dtype):
    # Declare the signature of PyMemoryView_FromMemory
    ctypes.pythonapi.PyMemoryView_FromMemory.restype = ctypes.py_object
    ctypes.pythonapi.PyMemoryView_FromMemory.argtypes = [
        ctypes.c_char_p, ctypes.c_ssize_t, ctypes.c_int
    ]
    logger.info(
        f'get_nparray: pointer = {void_pointer}, elem_size = {elem_size}')
    # Cast the void pointer to a char (byte) pointer
    char_pointer = ctypes.cast(void_pointer, ctypes.POINTER(ctypes.c_char))
    # Work out the element size and total byte count
    np_dtype = field_dtype_to_np_dtype(field_dtype)
    buf_bytes = elem_size * np.dtype(np_dtype).itemsize
    logger.info(f'get_nparray: buf_bytes = {buf_bytes}')
    # Create a memoryview over the memory (shared, zero-copy)
    mem_view = ctypes.pythonapi.PyMemoryView_FromMemory(
        char_pointer, buf_bytes, 0)  # number 0 represents PyBUF_READ
    logger.info(
        f'get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}')
    # Build a NumPy array on top of the memoryview
    buf = np.frombuffer(mem_view, np_dtype)
    return buf


def get_scalar_from_field(field):
    void_p = convert_capsule_to_void_p(field.data)  # extract the pointer
    np_array = get_nparray_from_void_p(void_p, 1, field.type)  # 1-element array
    return np_array[0]  # return the scalar value
```
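A sketch of the intended call site: TensorRT hands a plugin creator its attributes as `trt.PluginField`s whose `.data` wraps raw memory in a PyCapsule, and `get_scalar_from_field` unwraps that capsule and views the bytes with the dtype recorded in `field.type`. Names below are illustrative:

```python
# inside a plugin creator's create_plugin(name, field_collection):
for field in field_collection:
    if field.name == 'num_heads':                 # hypothetical field name
        num_heads = get_scalar_from_field(field)  # e.g. 8, read zero-copy
```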
References
• https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/_common.py
• https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/__init__.py
• https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/plugin/plugin.py
• https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/_utils.py

