ammo.py
try: import ammo.torch.quantization as atq from ammo.torch.export import export_model_configexcept ImportError: raise ImportError("AMMO toolkit is not installed. Please install it first.")
def _quantize_model(model: torch.nn.Module,
                    qformat: Literal['fp8', 'int8_sq', 'int4_awq'],
                    calib_dataloader: DataLoader,
                    quant_cfg_dict: Optional[Dict] = None) -> torch.nn.Module:
    """Quantize ``model`` with the AMMO toolkit and return it.

    Args:
        model: Module handed to ``atq.quantize``.
        qformat: Quantization format; one of ``'fp8'``, ``'int8_sq'`` or
            ``'int4_awq'``.
        calib_dataloader: Iterable of calibration batches. Each batch is
            passed to ``model`` as its single positional argument.
        quant_cfg_dict: Optional ``name -> cfg`` overrides written into the
            ``'quant_cfg'`` section of the FP8 configuration. It is only
            consulted when ``qformat == 'fp8'``.

    Returns:
        The ``model`` object that was passed in, after ``atq.quantize`` ran
        on it.

    Raises:
        AssertionError: If ``qformat`` is not a supported format.
        ValueError: Same condition, when assertions are disabled (``-O``).
    """
    # Local import so the block does not depend on a file-level ``import copy``.
    import copy

    assert qformat in ['fp8', 'int8_sq', 'int4_awq'], \
        f'Got unsupported AMMO quantization format, {qformat} '

    if qformat == "fp8":
        quant_cfg = atq.FP8_DEFAULT_CFG
        if quant_cfg_dict:
            # Apply the overrides to a private copy. Writing them straight
            # into atq.FP8_DEFAULT_CFG would mutate the toolkit's shared
            # default and leak into every later FP8 quantization.
            quant_cfg = copy.deepcopy(quant_cfg)
            for name, cfg in quant_cfg_dict.items():
                quant_cfg['quant_cfg'][name] = cfg
    elif qformat == "int8_sq":
        quant_cfg = atq.INT8_SMOOTHQUANT_CFG
    elif qformat == "int4_awq":
        quant_cfg = atq.INT4_AWQ_CFG
    else:
        # Only reachable when the assert above has been stripped (-O).
        raise ValueError(f"Unsupported quantization format: {qformat}")

    def calibrate_loop():
        """Run every calibration batch through the model once."""
        for idx, data in enumerate(calib_dataloader):
            logger.debug(f"Calibrating batch {idx}")
            model(data)

    logger.debug("Starting quantization...")
    atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    logger.debug("Quantization done")
    return model
def quantize_and_export(model: torch.nn.Module,
                        qformat: Literal['fp8', 'int8_sq', 'int4_awq'],
                        calib_dataloader: DataLoader,
                        export_path: Optional[Union[str, Path]] = None,
                        tensor_parallel_size: int = 1) -> torch.nn.Module:
    """Quantize ``model`` and, when a path is given, export the result.

    The model family is derived from the class name of ``model`` before any
    quantization work starts, so unsupported models fail early.

    Args:
        model: Module to quantize.
        qformat: Quantization format forwarded to ``_quantize_model``.
        calib_dataloader: Calibration batches forwarded to
            ``_quantize_model``.
        export_path: Destination of the export. When falsy, nothing is
            written.
        tensor_parallel_size: Forwarded to ``export_model_config`` as
            ``inference_tensor_parallel``.

    Returns:
        The quantized model.

    Raises:
        NotImplementedError: If the class name of ``model`` matches none of
            the known model families.
    """
    cls_name = type(model).__name__

    # Ordered lookup: the first family with a marker contained in the class
    # name wins.
    known_families = (
        (("Llama", ), "llama"),
        (("GPTJ", ), "gptj"),
        (("GPT2", ), "gpt2"),
        (("Falcon", "RW"), "falcon"),
    )
    model_type = next(
        (family for markers, family in known_families
         if any(marker in cls_name for marker in markers)),
        None,
    )
    if model_type is None:
        raise NotImplementedError(
            f"Deploying quantized model {cls_name} is not supported")

    model = _quantize_model(model,
                            qformat=qformat,
                            calib_dataloader=calib_dataloader)

    if not export_path:
        return model

    with torch.inference_mode():
        if qformat == "int4_awq":
            # int4_awq is stored as a plain state dict rather than through
            # export_model_config.
            torch.save(model.state_dict(), export_path)
        else:
            export_model_config(
                model,
                model_type,
                torch.float16,
                quantization=qformat,
                export_dir=export_path,
                inference_tensor_parallel=tensor_parallel_size,
            )
    logger.info(f"Quantized model exported to :{export_path}")
    return model
quant.py
# isort: offfrom ...quantization.layers import ( SmoothQuantAttention, SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP, SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear, WeightOnlyGroupwiseQuantRowLinear, WeightOnlyQuantColumnLinear, WeightOnlyQuantRowLinear)# isort: on
smooth_quantize
def _smooth_quantize_llama(model, quant_mode):
    """Swap the norm, attention and MLP of every LLaMA layer for SmoothQuant
    variants.

    Args:
        model: Model whose ``layers`` are rewritten in place.
        quant_mode: Quantization mode; must report activation-and-weight
            quantization. It is passed to every replacement layer and stored
            on the model as ``quant_mode``.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: If ``quant_mode`` is not an activation-and-weight
            mode, or a layer lacks one of the expected sub-modules.
    """
    assert quant_mode.has_act_and_weight_quant()
    for layer in model.layers:
        assert hasattr(layer,
                       "input_layernorm"), "The layer has no input_layernorm"
        layer.input_layernorm = SmoothQuantRmsNorm(
            normalized_shape=layer.hidden_size,
            dtype=layer.dtype,
            quant_mode=quant_mode)

        assert hasattr(layer, "attention"), "The layer has no attention"
        layer.attention = SmoothQuantAttention(
            layer.hidden_size,
            num_attention_heads=layer.num_attention_heads,
            num_kv_heads=layer.num_kv_heads,
            max_position_embeddings=layer.max_position_embeddings,
            num_layers=model.num_layers,
            dtype=layer.dtype,
            attention_mask_type=layer.attention_mask_type,
            position_embedding_type=layer.position_embedding_type,
            tp_group=layer.tp_group,
            tp_size=layer.tp_size,
            quant_mode=quant_mode,
            bias=False)

        assert hasattr(layer, "mlp"), "The layer has no mlp"
        # NOTE(review): this reads model.hidden_size while the sibling
        # replacements read layer.hidden_size - confirm they always agree.
        layer.mlp = SmoothQuantGatedMLP(hidden_size=model.hidden_size,
                                        ffn_hidden_size=layer.mlp_hidden_size,
                                        hidden_act=layer.hidden_act,
                                        dtype=layer.dtype,
                                        tp_group=layer.tp_group,
                                        tp_size=layer.tp_size,
                                        quant_mode=quant_mode,
                                        bias=False)

        # The message used to read "post_rmspost_layernormnorm" (garbled).
        assert hasattr(layer,
                       "post_layernorm"), "The layer has no post_layernorm"
        layer.post_layernorm = SmoothQuantRmsNorm(
            normalized_shape=layer.hidden_size,
            dtype=layer.dtype,
            quant_mode=quant_mode)

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_mode):
    """Apply SmoothQuant layer replacement to a supported model.

    Args:
        model: Model to convert. Only ``LLaMAForCausalLM`` is actually
            handled here.
        quant_mode: Quantization mode forwarded to the model-specific
            routine.

    Returns:
        The converted model.

    Raises:
        AssertionError: If ``model`` is not one of the accepted classes, or
            is accepted but has no SmoothQuant implementation in this
            function.
    """
    assert isinstance(model, (GPTLMHeadModel, LLaMAForCausalLM,
                              BloomForCausalLM)),\
        "Only GPTLMHeadModel, LLaMAForCausalLM and BloomForCausalLM are well tested now"

    if isinstance(model, LLaMAForCausalLM):
        return _smooth_quantize_llama(model, quant_mode)

    # Raised explicitly instead of ``assert False``: the assert form is
    # stripped under -O, which would make this function return None.
    raise AssertionError(
        f"Model {type(model).__name__} is not supported by SmoothQuant yet")
weight_only_quantize
def weight_only_quantize(model,
                         quant_mode,
                         exclude_modules=None,
                         current_key_name=None):
    """Recursively replace linear children with weight-only quantized layers.

    Every ``ColumnLinear`` / ``RowLinear`` child is replaced in
    ``model._modules`` unless its own name is listed in ``exclude_modules``
    or any entry of ``exclude_modules`` is a substring of the child's dotted
    path.

    Args:
        model: Module whose children are inspected and rewritten in place.
        quant_mode: Quantization mode; must be weight-only. It is passed to
            the replacement layers and stored on ``model`` as ``quant_mode``.
        exclude_modules: Names (or path fragments) to leave untouched.
            Defaults to ``['lm_head']``.
        current_key_name: Path components accumulated by the recursion;
            callers normally leave this as ``None``.

    Returns:
        The same ``model`` object.
    """
    assert quant_mode.is_weight_only()

    if exclude_modules is None:
        exclude_modules = ['lm_head']
    key_path = [] if current_key_name is None else current_key_name

    def _path_is_excluded():
        # Evaluated lazily, only for children that are replacement
        # candidates.
        dotted_path = '.'.join(key_path)
        return any(fragment in dotted_path for fragment in exclude_modules)

    for child_name, child in model.named_children():
        key_path.append(child_name)

        # Descend first so nested linear layers are handled as well.
        if list(child.children()):
            weight_only_quantize(child, quant_mode, exclude_modules, key_path)

        if isinstance(child,
                      ColumnLinear) and child_name not in exclude_modules:
            if not _path_is_excluded():
                common = dict(bias=child.bias is not None,
                              dtype=child.dtype,
                              tp_group=child.tp_group,
                              tp_size=child.tp_size,
                              quant_mode=quant_mode)
                model._modules[child_name] = WeightOnlyQuantColumnLinear(
                    in_features=child.in_features,
                    out_features=child.out_features * child.tp_size,
                    gather_output=child.gather_output,
                    **common)
        elif isinstance(child,
                        RowLinear) and child_name not in exclude_modules:
            if not _path_is_excluded():
                common = dict(bias=child.bias is not None,
                              dtype=child.dtype,
                              tp_group=child.tp_group,
                              tp_size=child.tp_size,
                              quant_mode=quant_mode)
                model._modules[child_name] = WeightOnlyQuantRowLinear(
                    in_features=child.in_features * child.tp_size,
                    out_features=child.out_features,
                    **common)

        key_path.pop(-1)

    setattr(model, 'quant_mode', quant_mode)
    return model
weight_only_groupwise_quantize
def weight_only_groupwise_quantize(model,
                                   quant_mode,
                                   group_size=128,
                                   pre_quant_scale=False,
                                   zero=False,
                                   exclude_modules=None,
                                   current_key_name=None):
    """Recursively replace linear children with group-wise weight-only layers.

    Every ``ColumnLinear`` / ``RowLinear`` child is replaced in
    ``model._modules`` unless its own name is listed in ``exclude_modules``
    or any entry of ``exclude_modules`` is a substring of the child's dotted
    path.

    Args:
        model: Module whose children are inspected and rewritten in place.
        quant_mode: Stored on ``model`` as ``quant_mode``; it is not passed
            to the replacement layers.
        group_size: Forwarded to the replacement layers.
        pre_quant_scale: Forwarded to the replacement layers.
        zero: Forwarded to the replacement layers.
        exclude_modules: Names (or path fragments) to leave untouched.
            Defaults to ``['lm_head']``.
        current_key_name: Path components accumulated by the recursion;
            callers normally leave this as ``None``.

    Returns:
        The same ``model`` object.
    """
    if exclude_modules is None:
        exclude_modules = ['lm_head']
    key_path = [] if current_key_name is None else current_key_name

    def _path_is_excluded():
        # Evaluated lazily, only for children that are replacement
        # candidates.
        dotted_path = '.'.join(key_path)
        return any(fragment in dotted_path for fragment in exclude_modules)

    for child_name, child in model.named_children():
        key_path.append(child_name)

        # Descend first so nested linear layers are handled as well.
        if list(child.children()):
            weight_only_groupwise_quantize(child, quant_mode, group_size,
                                           pre_quant_scale, zero,
                                           exclude_modules, key_path)

        if isinstance(child,
                      ColumnLinear) and child_name not in exclude_modules:
            if not _path_is_excluded():
                common = dict(group_size=group_size,
                              pre_quant_scale=pre_quant_scale,
                              zero=zero,
                              bias=child.bias is not None,
                              dtype=child.dtype,
                              tp_group=child.tp_group,
                              tp_size=child.tp_size)
                model._modules[
                    child_name] = WeightOnlyGroupwiseQuantColumnLinear(
                        in_features=child.in_features,
                        out_features=child.out_features * child.tp_size,
                        gather_output=child.gather_output,
                        **common)
        elif isinstance(child,
                        RowLinear) and child_name not in exclude_modules:
            if not _path_is_excluded():
                common = dict(group_size=group_size,
                              pre_quant_scale=pre_quant_scale,
                              zero=zero,
                              bias=child.bias is not None,
                              dtype=child.dtype,
                              tp_group=child.tp_group,
                              tp_size=child.tp_size)
                model._modules[child_name] = WeightOnlyGroupwiseQuantRowLinear(
                    in_features=child.in_features * child.tp_size,
                    out_features=child.out_features,
                    **common)

        key_path.pop(-1)

    setattr(model, 'quant_mode', quant_mode)
    return model
others
def get_dummy_quant_scales(num_layers):
    """Build a table of placeholder quantization scales.

    The two ``lm_head_*`` entries are scalars. Every other entry is a fresh
    list holding one value per layer: ``5.0`` for ``'qkv_output'`` and
    ``0.99`` for everything else.

    Args:
        num_layers: Length of each per-layer list.

    Returns:
        A dict mapping scale names to a float or a list of floats.
    """
    # (key, fill value) in the order the keys appear in the result.
    per_layer_entries = (
        ('fc_act', 0.99),
        ('fc_weights', 0.99),
        ('gate_act', 0.99),
        ('gate_weights', 0.99),
        ('proj_act', 0.99),
        ('proj_weights', 0.99),
        ('qkv_act', 0.99),
        ('qkv_weights', 0.99),
        ('qkv_output', 5.0),
        ('dense_act', 0.99),
        ('dense_weights', 0.99),
    )

    scales = {'lm_head_act': 0.99, 'lm_head_weights': 0.99}
    for key, fill in per_layer_entries:
        # A new list per key, so entries never alias one another.
        scales[key] = [fill] * num_layers
    return scales
def _quantize_layer(layer, layer_idx, quant_mode, quant_scales):
    """Write the scaling factors for one layer from ``quant_scales``.

    Sets the activation and weight scaling factors of the MLP ``fc`` and
    ``proj`` layers (plus ``gate`` when present) and of the attention ``qkv``
    and ``dense`` layers. When ``quant_mode`` reports an FP8 KV cache, the
    KV-cache scale and its reciprocal are set too.

    Args:
        layer: Layer exposing ``mlp`` and ``attention`` sub-modules.
        layer_idx: Index used to pick this layer's entry from each list in
            ``quant_scales``.
        quant_mode: Queried via ``has_fp8_kv_cache()``.
        quant_scales: Dict of per-layer scale lists, keyed like the result
            of ``get_dummy_quant_scales``.

    Returns:
        The same ``layer`` object.
    """
    assert hasattr(layer, "mlp"), "The layer has no mlp"
    fp8_linear_types = (FP8Linear, FP8RowLinear)

    def _scale(key, invert=False):
        # Each factor is stored as a one-element float32 array.
        raw = quant_scales[key][layer_idx]
        return np.array([1.0 / raw if invert else raw], dtype=np.float32)

    def _load_linear_scales(linear, act_key, weights_key):
        linear.activation_scaling_factor.value = _scale(act_key)
        linear.weights_scaling_factor.value = _scale(weights_key)

    mlp = layer.mlp
    assert isinstance(mlp.fc, fp8_linear_types)
    assert isinstance(mlp.proj, fp8_linear_types)
    _load_linear_scales(mlp.fc, 'fc_act', 'fc_weights')
    _load_linear_scales(mlp.proj, 'proj_act', 'proj_weights')

    # Only some MLP variants carry a gate projection.
    if hasattr(mlp, 'gate'):
        assert isinstance(mlp.gate, fp8_linear_types)
        _load_linear_scales(mlp.gate, 'gate_act', 'gate_weights')

    assert hasattr(layer, "attention"), "The layer has no attention"
    attention = layer.attention
    assert isinstance(attention.qkv, fp8_linear_types)
    assert isinstance(attention.dense, fp8_linear_types)
    _load_linear_scales(attention.qkv, 'qkv_act', 'qkv_weights')

    if quant_mode.has_fp8_kv_cache():
        # The two KV-cache factors are reciprocals of each other.
        attention.kv_orig_quant_scale.value = _scale('qkv_output')
        attention.kv_quant_orig_scale.value = _scale('qkv_output',
                                                     invert=True)

    _load_linear_scales(attention.dense, 'dense_act', 'dense_weights')
    return layer


def _default_fp8_quantize(model: Union[GPTLMHeadModel, LLaMAForCausalLM,
                                       GPTJForCausalLM],
                          quant_mode: QuantMode,
                          quant_scales: dict = None):
    """
    Quantize all linear layers (i.e., MLP, Attention QKV/Dense) and KV cache IO
    with dummy scales.

    This is used by the benchmark script and is therefore intentionally
    decoupled from the AMMO toolkit.

    When ``quant_scales`` is None, placeholder scales are generated from the
    model's ``_num_layers`` attribute, falling back to ``num_layers``.
    """
    if quant_scales is None:
        fallback = getattr(model, 'num_layers', None)
        num_layers = getattr(model, '_num_layers', fallback)
        assert num_layers is not None
        quant_scales = get_dummy_quant_scales(num_layers)

    assert model.quant_mode == quant_mode, "Quant setting not consistent with model init setting"

    assert quant_mode.has_fp8_qdq()

    for position, block in enumerate(model.layers):
        # The layer is updated in place; the return value is not needed.
        _quantize_layer(block, position, quant_mode, quant_scales)

    # TODO: add lm_head

    return model


def fp8_quantize(model, quant_mode: QuantMode, quant_scales: dict = None):
    """Apply FP8 scaling factors to a supported model.

    Args:
        model: Model to update; must be one of the supported model classes.
        quant_mode: Quantization mode forwarded to the default FP8 routine.
        quant_scales: Optional scale table; generated when omitted.

    Returns:
        The updated model.

    Raises:
        NotImplementedError: If ``model`` is not a supported model class.
    """
    supported_models = (FalconForCausalLM, GPTJForCausalLM, GPTLMHeadModel,
                        LLaMAForCausalLM)
    if not isinstance(model, supported_models):
        raise NotImplementedError(
            f"Model {model} is not implemented by fp8_quantize yet")
    return _default_fp8_quantize(model, quant_mode, quant_scales)
参考文献
- https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/models/quantized/ammo.py
- https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/models/quantized/quant.py