Turn on onnx fallthrough in convert_model() (#15651)

* Turn on ONNX_FALLTHROUGH in torch.onnx.export().

* Removed wrong change.

* Added test.
This commit is contained in:
Anastasiia Pnevskaia 2023-02-23 10:22:30 +01:00 committed by GitHub
parent bc663878eb
commit 1e24c51abb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 0 deletions

View File

@ -8,11 +8,23 @@ import numpy as np
import openvino.runtime as ov import openvino.runtime as ov
import pytest import pytest
import torch import torch
import unittest
from openvino.runtime import PartialShape, Dimension, Model, Type from openvino.runtime import PartialShape, Dimension, Model, Type
from common.mo_convert_test_class import CommonMOConvertTest from common.mo_convert_test_class import CommonMOConvertTest
class MyTorchOp(torch.autograd.Function):
@staticmethod
def symbolic(g, in_positions):
return g.op("MyTorchOp", in_positions)
@staticmethod
def forward(self, in_positions):
out_pos = in_positions.reshape(-1)
return out_pos + 0.5
def make_pt_model_one_input(): def make_pt_model_one_input():
from torch import nn from torch import nn
class NeuralNetwork(nn.Module): class NeuralNetwork(nn.Module):
@ -735,3 +747,30 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
if mo_params is not None: if mo_params is not None:
test_params.update(mo_params) test_params.update(mo_params)
self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False) self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False)
def create_pt_model_with_custom_op():
#
# Create PyTorch model with custom operation
#
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.my_op = MyTorchOp()
def forward(self, x):
return self.my_op.apply(x)
return MyModel()
class ConvertONNXFallthroughTest(unittest.TestCase):
def test_onnx_fallthrough(self):
from openvino.tools.mo import convert_model
pytorch_model = create_pt_model_with_custom_op()
# Check that ONNX conversion passed, so ONNX frontend raises error message of unsupported op.
with self.assertRaisesRegex(RuntimeError, ".*OpenVINO does not support the following ONNX operations: MyTorchOp.*"):
convert_model(pytorch_model, input_shape=[1, 2, 3], use_legacy_frontend=True)

View File

@ -131,6 +131,7 @@ def convert_pytorch_to_onnx(model, input_shape, opset_version, example_inputs, o
torch.onnx.export(model, torch.onnx.export(model,
inputs, inputs,
model_onnx, model_onnx,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
**additional_params) **additional_params)
return model_onnx return model_onnx