Turn on onnx fallthrough in convert_model() (#15651)
* Turn on ONNX_FALLTHROUGH in torch.onnx.export().
* Removed wrong change.
* Added test.
This commit is contained in:
parent bc663878eb
commit 1e24c51abb
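
For context (not part of this commit): a minimal sketch of what enabling ONNX_FALLTHROUGH in torch.onnx.export() looks like. The model, dummy input, and output file name below are placeholders, not names taken from this change.

import torch

# Placeholder model and input; only the operator_export_type argument is the point here.
model = torch.nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)

# With ONNX_FALLTHROUGH, ops that cannot be mapped to standard ONNX are emitted as
# custom nodes instead of aborting the export.
torch.onnx.export(model,
                  dummy_input,
                  "model.onnx",
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)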
@@ -8,11 +8,23 @@ import numpy as np
 import openvino.runtime as ov
 import pytest
 import torch
+import unittest
 from openvino.runtime import PartialShape, Dimension, Model, Type
 
 from common.mo_convert_test_class import CommonMOConvertTest
 
 
+class MyTorchOp(torch.autograd.Function):
+    @staticmethod
+    def symbolic(g, in_positions):
+        return g.op("MyTorchOp", in_positions)
+
+    @staticmethod
+    def forward(self, in_positions):
+        out_pos = in_positions.reshape(-1)
+        return out_pos + 0.5
+
+
 def make_pt_model_one_input():
     from torch import nn
     class NeuralNetwork(nn.Module):
@@ -735,3 +747,30 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
         if mo_params is not None:
             test_params.update(mo_params)
         self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False)
+
+
+def create_pt_model_with_custom_op():
+    #
+    # Create PyTorch model with custom operation
+    #
+    import torch.nn as nn
+
+    class MyModel(nn.Module):
+        def __init__(self):
+            super(MyModel, self).__init__()
+            self.my_op = MyTorchOp()
+
+        def forward(self, x):
+            return self.my_op.apply(x)
+
+    return MyModel()
+
+
+class ConvertONNXFallthroughTest(unittest.TestCase):
+    def test_onnx_fallthrough(self):
+        from openvino.tools.mo import convert_model
+        pytorch_model = create_pt_model_with_custom_op()
+
+        # Check that ONNX conversion passed, so ONNX frontend raises error message of unsupported op.
+        with self.assertRaisesRegex(RuntimeError, ".*OpenVINO does not support the following ONNX operations: MyTorchOp.*"):
+            convert_model(pytorch_model, input_shape=[1, 2, 3], use_legacy_frontend=True)
@@ -131,6 +131,7 @@ def convert_pytorch_to_onnx(model, input_shape, opset_version, example_inputs, o
     torch.onnx.export(model,
                       inputs,
                       model_onnx,
+                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
                       **additional_params)
     return model_onnx
 
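
For reference, a hedged usage sketch (not part of this commit) of the convert_model() entry point that this change affects; the model below is a placeholder. Roughly, on the legacy-frontend path the PyTorch model is first exported to ONNX via convert_pytorch_to_onnx(), which now exports with ONNX_FALLTHROUGH.

import torch
from openvino.tools.mo import convert_model

# Placeholder module; any torch.nn.Module works here. The keyword arguments mirror
# the ones used in the new test above.
model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
ov_model = convert_model(model, input_shape=[1, 4], use_legacy_frontend=True)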