Turn on onnx fallthrough in convert_model() (#15651)

* Turn on ONNX_FALLTHROUGH in torch.onnx.export().

* Removed incorrect change.

* Added test.
Anastasiia Pnevskaia 2023-02-23 10:22:30 +01:00 committed by GitHub
parent bc663878eb
commit 1e24c51abb
2 changed files with 40 additions and 0 deletions


@@ -8,11 +8,23 @@ import numpy as np
import openvino.runtime as ov
import pytest
import torch
import unittest
from openvino.runtime import PartialShape, Dimension, Model, Type

from common.mo_convert_test_class import CommonMOConvertTest

class MyTorchOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, in_positions):
        return g.op("MyTorchOp", in_positions)

    @staticmethod
    def forward(self, in_positions):
        out_pos = in_positions.reshape(-1)
        return out_pos + 0.5


def make_pt_model_one_input():
    from torch import nn
    class NeuralNetwork(nn.Module):
@@ -735,3 +747,30 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
        if mo_params is not None:
            test_params.update(mo_params)
        self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False)


def create_pt_model_with_custom_op():
    #
    #   Create PyTorch model with custom operation
    #
    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.my_op = MyTorchOp()

        def forward(self, x):
            return self.my_op.apply(x)

    return MyModel()


class ConvertONNXFallthroughTest(unittest.TestCase):
    def test_onnx_fallthrough(self):
        from openvino.tools.mo import convert_model

        pytorch_model = create_pt_model_with_custom_op()
        # Check that the conversion went through the ONNX path, so the ONNX frontend reports the unsupported op.
        with self.assertRaisesRegex(RuntimeError, ".*OpenVINO does not support the following ONNX operations: MyTorchOp.*"):
            convert_model(pytorch_model, input_shape=[1, 2, 3], use_legacy_frontend=True)


@@ -131,6 +131,7 @@ def convert_pytorch_to_onnx(model, input_shape, opset_version, example_inputs, o
    torch.onnx.export(model,
                      inputs,
                      model_onnx,
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
                      **additional_params)
    return model_onnx
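
Note (illustration, not part of this commit): a minimal standalone sketch of what the flag changes at the torch.onnx.export() level. The custom op mirrors MyTorchOp from the test above; the Model class and the in-memory buffer are made up for the example.

# Sketch only: export a model whose custom op has no standard ONNX equivalent.
# With ONNX_FALLTHROUGH the exporter keeps the "MyTorchOp" node in the graph instead
# of rejecting it, so the unsupported op is reported later by whatever consumes the
# ONNX model (in convert_model() that is OpenVINO's ONNX frontend, as the new test expects).
import io
import torch


class MyTorchOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x):
        # Emits a node that is not part of the standard ONNX operator set.
        return g.op("MyTorchOp", x)

    @staticmethod
    def forward(ctx, x):
        return x.reshape(-1) + 0.5


class Model(torch.nn.Module):
    def forward(self, x):
        return MyTorchOp.apply(x)


buffer = io.BytesIO()
torch.onnx.export(Model(),
                  torch.zeros(1, 2, 3),
                  buffer,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)

Without the flag, the export step itself would be expected to fail on the unknown op; with it, convert_pytorch_to_onnx() produces an ONNX model containing MyTorchOp, which is exactly the situation the new ConvertONNXFallthroughTest asserts on.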