Turn on onnx fallthrough in convert_model() (#15651)
* Turn on ONNX_FALLTHROUGH in torch.onnx.export(). * Removed wrong change. * Added test.
This commit is contained in:
parent
bc663878eb
commit
1e24c51abb
@ -8,11 +8,23 @@ import numpy as np
|
||||
import openvino.runtime as ov
|
||||
import pytest
|
||||
import torch
|
||||
import unittest
|
||||
from openvino.runtime import PartialShape, Dimension, Model, Type
|
||||
|
||||
from common.mo_convert_test_class import CommonMOConvertTest
|
||||
|
||||
|
||||
class MyTorchOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def symbolic(g, in_positions):
|
||||
return g.op("MyTorchOp", in_positions)
|
||||
|
||||
@staticmethod
|
||||
def forward(self, in_positions):
|
||||
out_pos = in_positions.reshape(-1)
|
||||
return out_pos + 0.5
|
||||
|
||||
|
||||
def make_pt_model_one_input():
|
||||
from torch import nn
|
||||
class NeuralNetwork(nn.Module):
|
||||
@ -735,3 +747,30 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
if mo_params is not None:
|
||||
test_params.update(mo_params)
|
||||
self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False)
|
||||
|
||||
|
||||
def create_pt_model_with_custom_op():
|
||||
#
|
||||
# Create PyTorch model with custom operation
|
||||
#
|
||||
import torch.nn as nn
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self.my_op = MyTorchOp()
|
||||
|
||||
def forward(self, x):
|
||||
return self.my_op.apply(x)
|
||||
|
||||
return MyModel()
|
||||
|
||||
|
||||
class ConvertONNXFallthroughTest(unittest.TestCase):
|
||||
def test_onnx_fallthrough(self):
|
||||
from openvino.tools.mo import convert_model
|
||||
pytorch_model = create_pt_model_with_custom_op()
|
||||
|
||||
# Check that ONNX conversion passed, so ONNX frontend raises error message of unsupported op.
|
||||
with self.assertRaisesRegex(RuntimeError, ".*OpenVINO does not support the following ONNX operations: MyTorchOp.*"):
|
||||
convert_model(pytorch_model, input_shape=[1, 2, 3], use_legacy_frontend=True)
|
||||
|
@ -131,6 +131,7 @@ def convert_pytorch_to_onnx(model, input_shape, opset_version, example_inputs, o
|
||||
torch.onnx.export(model,
|
||||
inputs,
|
||||
model_onnx,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
|
||||
**additional_params)
|
||||
return model_onnx
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user