Loading... ## Pytorch 模型导出为 ONNX 模型格式 PyTorch 是一个流行的开源机器学习库,它提供了灵活的计算图结构和强大的神经网络工具。`torch.onnx` 模块允许开发者将 PyTorch 的 `torch.nn.Module` 中定义的模型转换为开放式神经网络交换格式(ONNX)。ONNX 是一个跨平台的模型表示标准,使不同的深度学习框架可以共享模型。 PyTorch 提供了两种主要的方式来执行模型的 ONNX 导出: 1. **torch.onnx.export** 该 API 是基于 TorchScript 的,TorchScript 是 PyTorch 的一个子集,可以被用来创建可序列化和可优化的模型。`torch.onnx.export` 函数可以捕获 `torch.nn.Module` 中的计算图,并将其序列化为 ONNX 格式,该方法自 PyTorch 1.2.0 版本起就已经可用。 2. **torch.onnx.dynamo_export** 该 API,基于 PyTorch 2.0 发布的 TorchDynamo 技术。TorchDynamo 是一个基于 Python 的编译器堆栈,旨在自动优化 PyTorch 程序。`torch.onnx.dynamo_export` 函数也是用于将 `torch.nn.Module` 中的计算图转换为 ONNX 格式,但它是利用 TorchDynamo 对模型的导出流程进行了进一步优化。 本文介绍如何利用 `torch.onnx.export` 来实现模型转换并导出为 ONNX 格式,如果对 `torch.onnx.dynamo_export` 有兴趣可以参考 [将 PyTorch 模型导出到 ONNX](https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html) 进行尝试。 ### 1、安装所需的依赖项 由于 ONNX 导出器使用 onnx 并将 onnxscript PyTorch 运算符转换为 ONNX 运算符,因此我们需要安装它们。 ```bash pip install --upgrade onnx pip install --upgrade onnxscript pip install --upgrade onnxruntime ``` ### 2、编写神经网络模型 设置好环境后,我们需要使用 PyTorch 创建一个简单的神经网络模型,并根据定义的模型创建一个模型实例。同时,还需要对模型进行重新训练或者直接为模型加载预训练参数,然后将模型设置为推理模式,以应对 dropout 这类运算在训练和推理时的行为差异。 ```python # 模型实例化 my_model = models.resnet18(weights=None) ## `pretrained` 设置为 `True` 可以直接加载预训练参数 # my_model = models.resnet18(pretrained=True) # 加载模型预训练参数 model_weights = torch.load("/path/to/weights/resnet18-f37072fd.pth", weights_only=True) my_model.load_state_dict(model_weights) # 将模型转换为推理模式(必须) my_model.eval() ``` ### 3、将模型导出为 ONNX 格式 然后,我们可以借助 torch.onnx.export 函数就可以实现模型转换并导出成 ONNX 格式。在 ONNX 中,默认情况下模型的输入和输出尺寸是固定的。但是,您可以通过指定 `dynamic_axes` 参数使某些维度变为动态的。示例代码中,将第一个维度设置为动态,允许输入尺寸为 [batch_size, 3, 224, 224],其中 batch_size 可以是任意值。 ```python import torch import torch.onnx from torch import nn # 准备输入张量x,只要其类型和尺寸符合要求,其值可以是任意的 input_tensor = torch.randn(1, 3, 224, 224, requires_grad=True) # 实现模型转换并导出成 ONNX 格式 torch.onnx.export( model = my_model, args = input_tensor, f = "/path/to/save/model.onnx", export_params = True, opset_version = 10, do_constant_folding = True, input_names = ['input'], output_names = ['output'], dynamic_axes={ 'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'} } ) # 用于后续验证导出的 ONNX 模型计算值与原模型是否相等 pytorch_output = my_model(input_tensor) ``` 最后,您可以将 ONNX 文件重新加载到内存中,并使用以下代码检查其格式是否正确: ```python import onnx onnx_model = onnx.load("model.onnx") onnx.checker.check_model(onnx_model) ``` ### 4、使用 Netron 可视化 ONNX 模型图 模型保存成功后,您可使用 Netron 进行可视化。Netron 支持在 macOS、Linux、或 Windows 上本地安装,或者可以通过浏览器在线使用。可以访问 https://netron.app/ 体验 Netron 的网页版来查看模型。 ![](https://www.huarzone.com/usr/uploads/images/blog-images/210/4120720413.png) ### 5、使用 ONNX Runtime 执行 ONNX 模型 ONNX 标准并不支持 PyTorch 支持的所有数据结构和类型,因此我们需要先将 PyTorch 输入调整为 NumPy 数组,然后再将其提供给 ONNX Runtime。在准备数据供 ONNX Runtime 使用时,转换后的数组需要放入一个字典结构中,这个字典的键是模型输入层的名称,值则是对应的 NumPy 数组。 随后,我们可以初始化一个 ONNX Runtime 推理会话,并使用处理后的输入执行 ONNX 模型并获取输出。虽然本例中 ONNX 推理会话是在 CPU 上进行的,但其同样支持在 GPU 上进行推理,提供了灵活的平台选择。 ```python import numpy as np import onnxruntime # 用于将 PyTorch 张量转换为 NumPy 数组 def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() # 创建一个 ONNX 运行时的会话来进行推理 ort_session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"]) ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_tensor)} ort_output = ort_session.run(None, ort_inputs) ``` ### 6、将 PyTorch 结果与 ONNX 运行时的结果进行比较 为了验证 ONNX 模型和原始 PyTorch 模型的一致性,我们需确保两者使用相同的输入数据情况下,比较两个模型的输出(确保格式相同)。如果两者的输出结果在指定的相对容忍度(rtol=1e-03)和绝对容忍度(atol=1e-05)范围内不一致,系统会触发一个异常。 ```python # 验证 ONNX 运行时的输出和 PyTorch 的输出是否接近 try: np.testing.assert_allclose(to_numpy(pytorch_output), ort_output[0], rtol=1e-03, atol=1e-05) print("PyTorch and ONNX Runtime output matched!") except AssertionError as e: print("Detected output discrepancy: ", e) ``` 我们已顺利完成了将 PyTorch 模型转换为 ONNX 格式的全部步骤,包括模型创建并加载预训练参数、模型格式转换及导出,以及利用 Netron 工具进行可视化查看。此外,我们还利用 ONNX Runtime 运行了该模型,并对 ONNX 模型的输出结果与原始 PyTorch 模型的输出结果进行了数值比较验证。 Last modification:August 8, 2024 © Allow specification reprint Support Appreciate the author AliPayWeChat Like 如果觉得我的文章对你有用,请随意赞赏