徐知行

onnx 模型量化

onnx 模型量化

onnx 模型转换为 fp18,int8 的代码,包括静态量化和动态量化

预处理

一些量化需要保持 MB 写算子不被量化,或者只量化某些算子.该脚本用于打印 onnx 模型里面的算子类型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import onnx,sys

def print_onnx_layers(model_path):
    """
    打印 ONNX 模型中所有层的名称
    参数:
        model_path (str): ONNX 模型文件路径
    """
    # 加载模型
    model = onnx.load(model_path)
    
    # 验证模型格式
    if not isinstance(model, onnx.ModelProto):
        raise ValueError("Invalid ONNX model format")
    
    # 遍历计算图中的所有节点
    print(f"{'Index':<6} {'Layer Name':<40} {'Operator Type':<15}")
    print("-" * 70)
    for i, node in enumerate(model.graph.node):
        layer_name = node.name if node.name else "Unnamed"
        op_type = node.op_type
        # print(node.input)
        print(f"{i:<6} {layer_name:<40} {op_type:<15}    {node.input}")

if __name__ == "__main__":
    # 替换为你的 ONNX 模型路径
    model_path = sys.argv[1]
    
    # 打印模型中的每一层
    print_onnx_layers(model_path)

量化前先做一下预处理,有些模型做预处理可能失败,不做也可:python -m onnxruntime.quantization.preprocess --input models/sv_model/model.onnx --output models/sv_model/model.onnx.opt.onnx

onnx 模型转 fp16:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import onnx
from onnxconverter_common import float16
import sys
import os

DEFAULT_OP_BLOCK_LIST = float16.DEFAULT_OP_BLOCK_LIST+['ArrayFeatureExtractor', 'Binarizer', 'CastMap', 'CategoryMapper', 'DictVectorizer',
                                                       'FeatureVectorizer', 'Imputer', 'LabelEncoder', 'LinearClassifier', 'LinearRegressor',
                                                       'Normalizer', 'OneHotEncoder', 'RandomUniformLike', 'SVMClassifier', 'SVMRegressor', 'Scaler',
                                                       'TreeEnsembleClassifier', 'TreeEnsembleRegressor', 'ZipMap', 'NonMaxSuppression', 'TopK',
                                                       'RoiAlign', 'Resize', 'Range', 'CumSum', 'Min', 'Max', 'Upsample']
DEFAULT_OP_BLOCK_LIST = list(DEFAULT_OP_BLOCK_LIST)
node_block_list = ['/decoder/RandomNormalLike']

# sys.argv[1]是待转换的 onnx 模型
model = onnx.load(sys.argv[1])
model_fp16 = float16.convert_float_to_float16(
    model, keep_io_types=True, op_block_list=DEFAULT_OP_BLOCK_LIST, node_block_list=node_block_list)

print(node_block_list)
onnx.save(model_fp16, sys.argv[1]+"-fp16.onnx")

onnx 使用 int8 动态量化

量化类型选择 QInt8 时,,可能出现量化后的模型无法运行的问题,似乎主要是 conv 引起的,此时量化类型改为QInt8

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import onnx,sys
from onnxruntime.quantization import quantize_dynamic, QuantType

def dynamic_quantization(input_model_path, output_model_path):
    """
    对ONNX模型进行动态INT8量化
    
    参数:
        input_model_path: 原始FP32模型路径
        output_model_path: 量化后INT8模型保存路径
    """
    # 加载原始模型验证
    original_model = onnx.load(input_model_path)
    onnx.checker.check_model(original_model)
    print(f"✅ 原始模型验证通过: {input_model_path}")
    
    # 执行动态量化
    quantize_dynamic(
        input_model_path,          # 输入模型路径
        output_model_path,         # 输出模型路径
        # op_types_to_quantize=['MatMul', 'Attention', 'LSTM', 'Gather', 'Transpose', 'EmbedLayerNormalization'],
        nodes_to_exclude=['Conv','Mul'],
        # activation_type=QuantType.QUInt8,  # 激活值量化类型
        weight_type=QuantType.QUInt8,       # 权重量化类型
        per_channel=False,         # 是否使用逐通道量化
        reduce_range=False,        # 是否减少量化范围(某些CPU需要)
        use_external_data_format=False,  # 是否使用外部数据格式(大型模型需要)
        # optimize_model=True        # 量化前优化模型
    )
    
    print(f"🚀 动态量化完成! 量化模型已保存至: {output_model_path}")

if __name__ == "__main__":
    # 使用示例
    input_model = sys.argv[1]  # 替换为你的模型路径
    output_model = sys.argv[1]+".int8.onnx"  # 量化模型输出路径
    
    dynamic_quantization(input_model, output_model)

int8 静态量化

以下代码同时展示了如何构造校准数据集:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# python -m onnxruntime.quantization.preprocess --input models/sv_model/model.onnx --output models/sv_model/model.onnx.opt.onnx

import librosa
import numpy as np
import glob,os,sys,torchaudio
from onnxruntime.quantization import quantize_static, CalibrationMethod, QuantFormat, QuantType
try:
    from speakerlab.process.processor import FBank
except ImportError:
    sys.path.append('./export')
    from speakerlab.process.processor import FBank
feature_extractor_no_men = FBank(80, sample_rate=16000, mean_nor=False)
def load_wav(wav_file, obj_fs=16000):
    wav, fs = torchaudio.load(wav_file)
    if fs != obj_fs:
        print(f'[WARNING]: The sample rate of {wav_file} is not {obj_fs}, resample it.')
        wav, fs = torchaudio.sox_effects.apply_effects_tensor(
            wav, fs, effects=[['rate', str(obj_fs)]]
        )
    if wav.shape[0] > 1:
        wav = wav[0, :].unsqueeze(0)
    return wav

# 这里准备的校准数据集只需要准备输入就行了
def create_audio_calibration_dataset(audio_dir, num_samples=100):
    """
    创建音频模型的校准数据集
    
    参数:
        audio_dir: 音频文件目录
        num_samples: 样本数量
    """
    all_files = glob.glob(os.path.join(audio_dir, "*.wav"))
    selected_files = np.random.choice(all_files, num_samples, replace=True)
    
    calibration_data = []
    for file_path in selected_files:
        # 加载音频并提取特征
        wav=load_wav(file_path)
        # print("====",file_path)
        feat_no_mean = feature_extractor_no_men(wav).unsqueeze(0).numpy().astype(np.float32)
        calibration_data.append(feat_no_mean)  # 转置为时间序列
    
    return calibration_data


def static_quantization(model_path, quantized_model_path, calibration_dataset):
    """
    执行静态量化
    
    参数:
        model_path: 原始ONNX模型路径
        quantized_model_path: 量化模型保存路径
        calibration_dataset: 校准数据集
    """
    # 创建校准数据读取器
    class CalibrationDataReader:
        def __init__(self, data_list):
            self.data = data_list
            self.index = 0
            
        def get_next(self):
            if self.index < len(self.data):
                # 注意:'input_0' 应替换为模型的输入节点名称
                input_data = {"feat": self.data[self.index]}  
                self.index += 1
                return input_data
            return None
        
        def rewind(self):
            self.index = 0

    # 执行静态量化
    quantize_static(
        model_input=model_path,
        model_output=quantized_model_path,
        calibration_data_reader=CalibrationDataReader(calibration_dataset),
        # op_types_to_quantize=['Conv','Mul',"AveragePool","BatchNormalization","ReduceMean",],
        op_types_to_quantize=['Conv','Mul'],
        quant_format=QuantFormat.QOperator,  # 量化格式
        activation_type=QuantType.QInt8,     # 激活值量化类型
        weight_type=QuantType.QInt8,         # 权重量化类型
        calibrate_method=CalibrationMethod.MinMax,  # 校准方法
        extra_options = {"ExtraSymmetric": True}    # 额外选项
    )
# 使用示例
if __name__ == "__main__":
    # 1. 准备校准数据集
    calibration_data = create_audio_calibration_dataset("./", 100)
    
    # 2. 执行静态量化
    static_quantization(
        "models/sv_model/model.onnx",
        "models/sv_model/model.onnx.sint8.onnx",
        calibration_data
    )

需要注意的是,最好指定量化的算子类型,上面的 demo里面,额外指定了”AveragePool”,”BatchNormalization”,”ReduceMean”算子,速度不仅没有提升,还稍微有些下降. 至于量化格式,可以选择 QDQ 和QOperator,QDQ是和在 x86 服务器上,QOperator适合在 arm 端侧设备上. 另外量化类型也对速度有较大的影响,在 arm 端侧设备上使用 QInt8 相比 QUInt8 有约 20%提升;而在 x86 的 linux 上使用 QInt8 相比 QUInt8 有好几倍的提升,主要是 QUInt8 速度比不量化慢了好几倍. 以上主要是看推理速度这一指标,并没有考虑准确性,实际使用时还要考虑量化误差,综合来选择.

comments powered by Disqus