閱讀時間約 2 分鐘

809 字

現在的模型都越來越巨大了,尤其是語言模型。搜尋引擎需要即時回傳結果,手機上的模型只有極小的記憶體以及計算能力,都需要將模型輕量化。ONNX是微軟開發的跨平台機器學習套件,可以將各種框架 (PyTorch、TensorFlow 等等) 的模型轉成 ONNX 格式,並且做輕量化。本文以 PyTorch 模型作為範例,使用 ONNX 將之輕量化,在相似精確度下獲得更小的模型、更快的推論速度。所有程式碼都會放在 GitHub model_quatization 上。

將 PyTorch 模型轉成 ONNX 格式

將模型設為 eval 模式以後,提供範例輸入便可使用 torch.onnx.export 將 PyTorch 模型轉成 ONNX 格式。其中 dynamic_axes 是告訴 ONNX 我們的 batch size 跟 sequence length 為動態長度。

example_inputs = tokenizer("query: this is a test sentence", return_tensors="pt")
model.eval()
torch.onnx.export(
    model,
    tuple((example_inputs['input_ids'], example_inputs['attention_mask'])),
    "models/distilbert-imdb.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "max_seq_len"},
        "attention_mask": {0: "batch_size", 1: "max_seq_len"},
        "output": {0: "batch_size"},
    },
    opset_version=17,
    export_params=True,
    do_constant_folding=True,
)

模型量化

模型量化 (Quantization) 是將用浮點數表示權重的模型,將大部分的權重以整數表示,用來降低模型大小以及加快推論速度。從 32 bit 浮點數轉成 8 bit 整數,模型就小了四分之一,而浮點數的計算也比整數複雜,需要更久的時間。

量化分成訓練後量化 (Post-training Quantization) 以及考慮量化的訓練 (Quantization-aware Training),前者可以直接將模型量化,後者是在模型訓練時模擬量化後的模型行為,更好地保持模型的精確度。訓練後量化實作相對簡單,本文就以訓練後量化作為範例。

將 ONNX 模型量化

量化分成動態量化與靜態量化 (dynamic and static quantization),其中動態量化是在執行推論時才根據推論的資料計算量化後的 activation functions 的參數,而靜態量化是事先用一個資料集將參數計算好。

ONNX 官方建議 RNN 與 transformer-based models 使用動態量化,CNN models 使用靜態量化。

quantize_dynamic(
        model_input="models/distilbert-imdb.onnx",
        model_output="models/distilbert-imdb.int8.onnx",
        weight_type=QuantType.QInt8,
        extra_options=dict(
            EnableSubgraph=True
        ),
    )

實驗使用 HuggingFace 上面其中一個以 IMDB dataset finetune 的 distilbert

以下是跑在 Macbook Air M1 CPU 與 Windows 10 WSL (Linux 子系統) i5-8400 CPU 上面的結果 (不同平台可能會有不同的結果):

模型大小每筆資料推論時間準確率
PyTorch Model (MAC)256MB71.1ms93.8%
ONNX Model(MAC)256MB113.5ms93.8%
ONNX 8-bit Model(MAC)64MB87.7ms93.75%
PyTorch Model (Win)256MB78.6ms93.8%
ONNX Model(Win)256MB85.1ms93.8%
ONNX 8-bit Model(Win)64MB61.1ms93.85%

GPU

ONNX 將模型量化後,如果使用不支援 GPU 推論的 operator,會使用 CPU 計算。此時會有額外的資料傳輸成本,有可能造成量化後的模型比量化前更慢。

參考資料

  1. ONNX Runtime
  2. GitHub nixiesearch/onnx-convert
  3. Distilbert IMDB finetune 版本
  4. IMDB dataset
comments powered by Disqus

最新文章

分類

標籤