| import sys |
|
|
# Oldest interpreter release this script accepts, as a (major, minor) pair.
MIN_PYTHON_VERSION = (3, 7)

# Abort before the remaining imports run when the interpreter is too old.
if sys.version_info < MIN_PYTHON_VERSION:
    raise ImportError("This script requires Python 3.7 or higher!")
|
|
| import argparse |
| import os |
| from dataclasses import dataclass, field |
| from typing import Dict, Tuple |
| from enum import Enum, auto |
|
|
| import numpy as np |
| import onnx |
| from onnx import helper |
|
|
# Supported quantization bit widths and the signed NumPy integer type that
# stores the quantized values for each of them.
BITS_TO_NUMPY_TYPE = {
    8: np.int8,
    16: np.int16,
}

# Operator types whose constant inputs are candidates for block quantization.
SUPPORTED_OPS = {"Conv", "Gemm", "MatMul"}

# Opset version the model is converted to before it is rewritten.
ONNX_OPSET = 21
|
|
|
|
class WeightCategory(Enum):
    """Where the tensor behind a node input name is stored in the graph."""

    # The name refers to a graph initializer.
    INITIALIZER = 1
    # The name is the output of a Constant node.
    CONSTANT = 2
    # The name is neither of the above.
    NONE = 3
|
|
|
|
@dataclass
class BlockQuantizeConfig:
    """User-supplied settings for one block-quantization run."""

    # Path of the ONNX model to read.
    input_model_path: str
    # Path the quantized ONNX model is written to.
    output_model_path: str
    # Upper bound on the number of elements that share one scale/zero point.
    block_size: int
    # Width of the quantized integers, in bits.
    bits: int
    # When true, the summary also lists the error of every quantized weight.
    verbose: bool
|
|
|
|
@dataclass
class BlockQuantizeResult:
    """Everything produced by block-quantizing a single weight tensor.

    Array fields default to a fresh empty float array per instance.
    """

    # Integer weights after quantization.
    quantized_weights: np.ndarray = field(default_factory=lambda: np.empty(0))
    # One scale per block.
    scales: np.ndarray = field(default_factory=lambda: np.empty(0))
    # One zero point per block.
    zero_point: np.ndarray = field(default_factory=lambda: np.empty(0))
    # Block size that was actually used.
    block_size: int = 1
    # Axis along which the blocks run.
    axis: int = 1
    # Shape the weight had before quantization.
    original_shape: Tuple = field(default_factory=tuple)
    # Error between the original and the reconstructed weight.
    quantization_error: np.ndarray = field(default_factory=lambda: np.empty(0))
|
|
|
|
def closest_divisor(number: int, divisor: int) -> int:
    """Return the largest integer in ``[1, divisor]`` that divides ``number``.

    Candidates are tried from ``divisor`` downwards; ``1`` is returned when
    there is no candidate at all (``divisor`` below 1).
    """
    candidates = range(divisor, 0, -1)
    return next((d for d in candidates if number % d == 0), 1)
|
|
|
|
def block_dequantize_tensor(
    x: np.ndarray, block_axis: int, scale: np.ndarray, zero_point: np.ndarray
) -> np.ndarray:
    """Map block-quantized integers back to floating point.

    Args:
        x: Quantized tensor.
        block_axis: Axis of ``x`` along which the blocks run.
        scale: Per-block scales; along ``block_axis`` it holds one entry per
            block.
        zero_point: Per-block zero points, laid out like ``scale``.

    Returns:
        ``(x - zero_point) * scale`` with the per-block parameters expanded to
        one value per element of ``x``.
    """
    # Number of consecutive elements that share one scale / zero point.
    elements_per_block = x.shape[block_axis] // scale.shape[block_axis]

    expanded_scale = np.repeat(scale, repeats=elements_per_block, axis=block_axis)
    expanded_zero_point = np.repeat(
        zero_point, repeats=elements_per_block, axis=block_axis
    )

    # Subtract in float32 so the difference cannot wrap in the integer type.
    centered = x.astype(np.float32) - expanded_zero_point.astype(np.float32)
    return centered * expanded_scale
|
|
|
|
def block_quantize_tensor(
    x: np.ndarray,
    block_axis: int,
    scale: np.ndarray,
    zero_point: np.ndarray,
    n_bits: int,
) -> np.ndarray:
    """Quantize ``x`` using one scale and zero point per block.

    Args:
        x: Tensor to quantize.
        block_axis: Axis of ``x`` along which the blocks run.
        scale: Per-block scales; along ``block_axis`` it holds one entry per
            block.
        zero_point: Per-block zero points, laid out like ``scale``.
        n_bits: Bit width; must be a key of ``BITS_TO_NUMPY_TYPE``.

    Returns:
        ``rint(x / scale + zero_point)`` clipped to the range of the target
        integer type and cast to that type.
    """
    # Number of consecutive elements that share one scale / zero point.
    elements_per_block = x.shape[block_axis] // scale.shape[block_axis]

    expanded_scale = np.repeat(scale, repeats=elements_per_block, axis=block_axis)
    expanded_zero_point = np.repeat(
        zero_point, repeats=elements_per_block, axis=block_axis
    )

    target_dtype = BITS_TO_NUMPY_TYPE[n_bits]
    limits = np.iinfo(target_dtype)

    rounded = np.rint(x / expanded_scale + expanded_zero_point)
    # Saturate before the cast so out-of-range values do not wrap around.
    saturated = np.clip(rounded, limits.min, limits.max)
    return saturated.astype(target_dtype)
|
|
|
|
def create_dequantize_node(
    node_name,
    quantized_weights,
    scales,
    zero_point,
    dequantized_weights,
    block_size,
    axis,
) -> onnx.NodeProto:
    """Build a blocked ``DequantizeLinear`` node.

    Args:
        node_name: Name given to the new node.
        quantized_weights: Name of the quantized input tensor.
        scales: Name of the scales tensor.
        zero_point: Name of the zero-point tensor.
        dequantized_weights: Name of the node's output tensor.
        block_size: Value of the node's ``block_size`` attribute.
        axis: Value of the node's ``axis`` attribute.

    Returns:
        The node, carrying ``block_size`` and ``axis`` attributes in that order.
    """
    extra_attributes = [
        helper.make_attribute("block_size", block_size),
        helper.make_attribute("axis", axis),
    ]

    node = helper.make_node(
        "DequantizeLinear",
        name=node_name,
        inputs=[quantized_weights, scales, zero_point],
        outputs=[dequantized_weights],
    )
    node.attribute.extend(extra_attributes)
    return node
|
|
|
|
def create_reshape_node(
    node_name, dequantized_weights, shape_tensor, reshaped_weights_name
) -> onnx.NodeProto:
    """Build a ``Reshape`` node.

    Args:
        node_name: Name given to the new node.
        dequantized_weights: Name of the tensor to reshape.
        shape_tensor: Name of the tensor holding the target shape.
        reshaped_weights_name: Name of the node's output tensor.
    """
    reshape = helper.make_node(
        "Reshape",
        name=node_name,
        inputs=[dequantized_weights, shape_tensor],
        outputs=[reshaped_weights_name],
    )
    return reshape
|
|
|
|
class BlockQuantizer:
    """Rewrite an ONNX model so that eligible weights are block-quantized.

    Constant inputs (graph initializers or ``Constant`` node outputs) of the
    operators listed in ``SUPPORTED_OPS`` are replaced by an integer tensor
    plus per-block scales and zero points. A ``DequantizeLinear`` node, and a
    ``Reshape`` node for weights with more than two dimensions, restores the
    float tensor in the graph.
    """

    def __init__(self, conf: "BlockQuantizeConfig") -> None:
        """Validate ``conf``, load the model and index its constant tensors.

        Raises:
            ValueError: If ``conf`` fails ``validate_conf``.
        """
        self.conf = conf
        self.validate_conf()

        self.model = onnx.load(conf.input_model_path)

        # opset_import holds one entry per operator domain in no guaranteed
        # order, so look up the default domain instead of trusting index 0.
        default_opset = next(
            (
                opset.version
                for opset in self.model.opset_import
                if opset.domain in ("", "ai.onnx")
            ),
            None,
        )
        if default_opset != ONNX_OPSET:
            self.model = onnx.version_converter.convert_version(
                self.model, ONNX_OPSET
            )

        self.graph = self.model.graph
        self.initializers_map = {init.name: init for init in self.graph.initializer}

        # NOTE: the attribute name (with its spelling) is kept for
        # compatibility with existing users of this class.
        self.costants_map = {}
        for node in self.graph.node:
            if node.op_type != "Constant":
                continue
            # A Constant may carry its payload in another attribute (for
            # example value_float); only tensor-valued constants are indexed.
            tensor = next(
                (attr.t for attr in node.attribute if attr.name == "value"), None
            )
            if tensor is not None:
                self.costants_map[node.output[0]] = tensor

    def validate_conf(self):
        """Check the configuration stored in ``self.conf``.

        Raises:
            ValueError: If the input file is missing, either path lacks the
                ``.onnx`` extension, the block size is not positive, or the
                bit width is not a key of ``BITS_TO_NUMPY_TYPE``.
        """
        if not os.path.isfile(self.conf.input_model_path):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' does not exist or is not a file."
            )

        if not self.conf.input_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' must have a .onnx extension."
            )

        if not self.conf.output_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Output model path '{self.conf.output_model_path}' must have a .onnx extension."
            )

        if self.conf.block_size <= 0:
            raise ValueError("Block size must be a positive integer.")

        if self.conf.bits not in BITS_TO_NUMPY_TYPE:
            allowed_values = ", ".join(str(k) for k in BITS_TO_NUMPY_TYPE)
            raise ValueError(
                f"Bits must be one of the following values: [{allowed_values}]."
            )

    def get_weight_category(self, name: str) -> "WeightCategory":
        """Classify ``name`` as initializer, Constant output, or neither.

        The lookup uses the maps built in ``__init__``, so a weight that has
        already been removed from the graph is still recognised.
        """
        if name in self.initializers_map:
            return WeightCategory.INITIALIZER
        if name in self.costants_map:
            return WeightCategory.CONSTANT
        return WeightCategory.NONE

    def get_weight_tensor(self, name: str, category: "WeightCategory") -> np.ndarray:
        """Return the tensor called ``name`` as a NumPy array.

        Raises:
            AssertionError: If ``category`` is ``WeightCategory.NONE``.
        """
        if category == WeightCategory.INITIALIZER:
            return onnx.numpy_helper.to_array(self.initializers_map[name])
        if category == WeightCategory.CONSTANT:
            return onnx.numpy_helper.to_array(self.costants_map[name])
        raise AssertionError("Invalid weight category")

    def remove_fp32_weights(self, name: str, category: "WeightCategory"):
        """Delete the original weight ``name`` from the graph.

        An initializer is removed from ``graph.initializer``; for a constant
        the producing ``Constant`` node is removed from ``graph.node``.

        Raises:
            AssertionError: If ``category`` is ``WeightCategory.NONE``.
        """
        if category == WeightCategory.INITIALIZER:
            self.graph.initializer.remove(
                next(init for init in self.graph.initializer if init.name == name)
            )
        elif category == WeightCategory.CONSTANT:
            self.graph.node.remove(
                next(
                    node
                    for node in self.graph.node
                    if node.op_type == "Constant" and node.output[0] == name
                )
            )
        else:
            raise AssertionError("Invalid weight category")

    def compute_scale_zeropoint(
        self, b_min: np.ndarray, b_max: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Derive per-block scales and zero points from per-block extrema.

        Args:
            b_min: Minimum of every block.
            b_max: Maximum of every block, same shape as ``b_min``.

        Returns:
            ``(scales, zeropoints)``; the zero points are cast to the integer
            type selected by ``self.conf.bits``.
        """
        assert (
            b_min <= b_max
        ).all(), "minimum must not be greater than maximum when computing scale and zero point"

        # Widen every block range so that it contains zero; this keeps the
        # zero point computed below inside [qmin, qmax].
        b_min = np.minimum(b_min, np.zeros_like(b_min, dtype=b_min.dtype))
        b_max = np.maximum(b_max, np.zeros_like(b_max, dtype=b_max.dtype))

        type_info = np.iinfo(BITS_TO_NUMPY_TYPE[self.conf.bits])
        qmin = type_info.min
        qmax = type_info.max

        dq = qmax - qmin

        # A block whose range collapsed to a single value (all zeros after the
        # widening above) gets scale 1 and zero point 0.
        scales = np.where(b_max != b_min, (b_max - b_min) / dq, 1.0)

        zeropoints = np.where(b_max != b_min, np.rint(qmin - b_min / scales), 0.0)
        zeropoints = zeropoints.astype(BITS_TO_NUMPY_TYPE[self.conf.bits])

        return (scales, zeropoints)

    def block_quantize(self, weight: np.ndarray) -> "BlockQuantizeResult":
        """Block-quantize one weight tensor and measure the resulting error.

        Tensors with more than one dimension are flattened to
        ``(shape[0], -1)`` and blocked along axis 1; one-dimensional tensors
        are blocked along axis 0. The block size is the largest divisor of the
        blocked axis length that does not exceed ``self.conf.block_size``.

        Returns:
            A ``BlockQuantizeResult`` whose ``quantization_error`` is the norm
            of the reconstruction error divided by the norm of ``weight``.
        """
        original_shape = weight.shape

        if weight.ndim > 1:
            weight = weight.reshape((weight.shape[0], -1))
            quantization_axis = 1
        else:
            quantization_axis = 0

        block_size = closest_divisor(
            weight.shape[quantization_axis], self.conf.block_size
        )

        assert (
            weight.shape[quantization_axis] % block_size == 0
        ), f"weight shape ({weight.shape[quantization_axis]}) must be divisible by block size ({block_size})"

        # Split the quantization axis into (number of blocks, block size) so
        # the per-block extrema are reductions over the last axis.
        new_shape = list(weight.shape[: quantization_axis + 1]) + [-1]
        new_shape[quantization_axis] = new_shape[quantization_axis] // block_size

        blocked_weight = weight.reshape(new_shape)

        blocked_max = np.max(blocked_weight, -1)
        blocked_min = np.min(blocked_weight, -1)

        scales, zeropoints = self.compute_scale_zeropoint(blocked_min, blocked_max)

        quantized_weight = block_quantize_tensor(
            weight, quantization_axis, scales, zeropoints, self.conf.bits
        )
        reconstructed_mat = block_dequantize_tensor(
            quantized_weight, quantization_axis, scales, zeropoints
        )

        # The small constant avoids a division by zero for an all-zero weight.
        qerror = np.linalg.norm(reconstructed_mat - weight) / (
            np.linalg.norm(weight) + 1e-10
        )

        return BlockQuantizeResult(
            quantized_weight,
            scales,
            zeropoints,
            block_size,
            quantization_axis,
            original_shape,
            qerror,
        )

    def get_model_size(self, model_path: str) -> float:
        """Return the size of the file at ``model_path`` in units of 1024 bytes."""
        size_bytes = os.path.getsize(model_path)
        size_kb = size_bytes / 1024

        return size_kb

    def display_summary(self, sqe: Dict[str, float]):
        """Print the mean quantization error and the model sizes.

        Args:
            sqe: Maps the name of every quantized weight to its relative error.
                May be empty when nothing was quantized.
        """
        errors = list(sqe.values())
        if not errors:
            mse = 0
            print(
                "Warning: No weights have been quantized, likely due to unsupported layers."
            )
        else:
            mse = sum(errors) / len(errors)
        original_model_size = self.get_model_size(self.conf.input_model_path)
        quantized_model_size = self.get_model_size(self.conf.output_model_path)

        # With no quantized weights there is nothing to list, and max() over
        # an empty sequence would raise.
        if self.conf.verbose and sqe:
            sorted_sqe = sorted(sqe.items(), key=lambda item: item[1], reverse=True)
            longest_key_len = max(len(key) for key in sqe)

            print("Quantization error (Relative Norm) sorted in descending order:")

            for key, value in sorted_sqe:
                print(f"{key:<{longest_key_len}} : {value}")

        print("Done! Results saved in", self.conf.output_model_path)
        print("\nSummary of Results:\n")
        print(f"{'Metric':<30} {'Value':<10}")
        print(f"{'-'*40}")
        print(f"{'Relative Norm Error':<31} {mse:.6f}")
        print(f"{'Original Model Size (KB)':<31} {original_model_size:,.2f}")
        print(f"{'Block-Quantized Model Size (KB)':<30} {quantized_model_size:,.2f}")

    def run(self):
        """Quantize every eligible weight, then check, save and summarise.

        A weight is skipped when it has fewer elements than the configured
        block size or when the only usable block size is 1. A weight shared by
        several nodes is quantized once and the other consumers are rewired to
        the same dequantized tensor.
        """
        print("Quantizing the model...")

        quantized_inputs = set()
        sqe = {}

        # Iterate over a snapshot: the loop body removes Constant nodes and
        # inserts new nodes at the front of graph.node, which would shift the
        # positions seen by an index-based walk and make it skip nodes.
        for node in list(self.graph.node):
            if node.op_type not in SUPPORTED_OPS:
                continue

            for input_idx, input_name in enumerate(node.input):
                weight_category = self.get_weight_category(input_name)

                # Not a constant tensor (for example an activation).
                if weight_category == WeightCategory.NONE:
                    continue

                weight = self.get_weight_tensor(input_name, weight_category)

                # Too small to fill even a single block.
                if weight.size < self.conf.block_size:
                    continue

                # The weight is quantized in a 2-D layout, so tensors of
                # higher rank must be reshaped back after dequantization.
                reshape_needed = weight.ndim > 2

                dequantized_weights_name = f"{input_name}_dequantized"
                reshaped_weights_name = f"{input_name}_reshaped"
                replacement_name = (
                    reshaped_weights_name
                    if reshape_needed
                    else dequantized_weights_name
                )

                # Already quantized for another consumer: just rewire.
                if input_name in quantized_inputs:
                    node.input[input_idx] = replacement_name
                    continue

                block_quantize_res = self.block_quantize(weight)

                # A block size of 1 brings no benefit; keep the float weight.
                if block_quantize_res.block_size == 1:
                    continue

                quantized_inputs.add(input_name)

                quantized_weights_name = f"{input_name}_quantized"
                scales_name = f"{input_name}_scales"
                zero_point_name = f"{input_name}_zero_point"

                dequantize_node = create_dequantize_node(
                    f"{input_name}_quantized_node",
                    quantized_weights_name,
                    scales_name,
                    zero_point_name,
                    dequantized_weights_name,
                    block_quantize_res.block_size,
                    block_quantize_res.axis,
                )

                new_initializers = [
                    onnx.numpy_helper.from_array(
                        block_quantize_res.scales, name=scales_name
                    ),
                    onnx.numpy_helper.from_array(
                        block_quantize_res.zero_point, name=zero_point_name
                    ),
                ]

                reshape_node = None
                reshaped_weights_info = None
                if reshape_needed:
                    shape_name = f"{input_name}_shape"
                    reshape_node = create_reshape_node(
                        f"{input_name}_shape_node",
                        dequantized_weights_name,
                        shape_name,
                        reshaped_weights_name,
                    )
                    # Reshape requires an int64 shape input; NumPy's default
                    # integer type is platform dependent, so set it explicitly.
                    new_initializers.append(
                        onnx.numpy_helper.from_array(
                            np.array(block_quantize_res.original_shape, dtype=np.int64),
                            name=shape_name,
                        )
                    )
                    reshaped_weights_info = helper.make_tensor_value_info(
                        reshaped_weights_name,
                        onnx.TensorProto.FLOAT,
                        block_quantize_res.original_shape,
                    )

                new_initializers.append(
                    onnx.numpy_helper.from_array(
                        block_quantize_res.quantized_weights,
                        name=quantized_weights_name,
                    )
                )

                dequantized_weights_info = helper.make_tensor_value_info(
                    dequantized_weights_name,
                    onnx.TensorProto.FLOAT,
                    block_quantize_res.quantized_weights.shape,
                )

                self.graph.initializer.extend(new_initializers)
                self.remove_fp32_weights(input_name, weight_category)
                node.input[input_idx] = replacement_name

                # Insert at the front so the new nodes precede every consumer;
                # Reshape goes in first so DequantizeLinear ends up before it.
                if reshape_node is not None:
                    self.graph.node.insert(0, reshape_node)
                self.graph.node.insert(0, dequantize_node)

                if reshaped_weights_info is not None:
                    self.graph.value_info.insert(0, reshaped_weights_info)
                self.graph.value_info.insert(0, dequantized_weights_info)

                sqe[input_name] = block_quantize_res.quantization_error

        onnx.checker.check_model(self.model, full_check=True)
        onnx.save(self.model, self.conf.output_model_path)

        self.display_summary(sqe)
|
|
|
|
def setup_args() -> argparse.Namespace:
    """Parse the command line of the block-quantization tool.

    Returns:
        A namespace with ``input_model``, ``block_size``, ``bits``,
        ``output_model`` and ``verbose``.
    """
    cli = argparse.ArgumentParser(description="Blockwise quantization tool")

    # (flags, keyword arguments) for every option, in registration order.
    option_specs = (
        (
            ("-i", "--input_model"),
            dict(type=str, help="The path of onnx model to quantize", required=True),
        ),
        (
            ("-bs", "--block_size"),
            dict(type=int, help="The maximum size of quantization block", required=True),
        ),
        (
            ("-b", "--bits"),
            dict(
                type=int,
                help="Quantization bits",
                choices=[8, 16],
                default=8,
                required=False,
            ),
        ),
        (
            ("-o", "--output_model"),
            dict(
                type=str,
                help="The output model path",
                default="block_quantized_model.onnx",
                required=False,
            ),
        ),
        (
            ("-v", "--verbose"),
            dict(action="store_true", help="Enable verbose output", required=False),
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: translate the command line into a configuration,
    # then load, quantize and save the model.
    args = setup_args()

    quantization_config = BlockQuantizeConfig(
        args.input_model,
        args.output_model,
        args.block_size,
        args.bits,
        args.verbose,
    )

    quantizer = BlockQuantizer(quantization_config)
    quantizer.run()
|
|