[MO] fix UnaryElementwise reverse_infer (#15366)

* fix UnaryElementwise reverse_infer

* fixed tests for UnaryElementwise reverse_infer

* reverted autocorrection edits
This commit is contained in:
Pavel Esir 2023-01-31 08:56:24 +01:00 committed by GitHub
parent d092f5d7dd
commit 1ae0b2796e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 2 deletions

View File

@ -4,7 +4,7 @@
import numpy as np
from openvino.tools.mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer, eltwise_reverse_infer
from openvino.tools.mo.front.common.partial_infer.utils import float32_array
from openvino.tools.mo.front.common.partial_infer.utils import float32_array, reverse_bypass_infer
from openvino.tools.mo.graph.graph import Graph, Node
from openvino.tools.mo.middle.passes.infer import copy_type_infer
from openvino.tools.mo.ops.op import Op
@ -48,6 +48,7 @@ class UnaryElementwise(Elementwise):
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {**{
'in_ports_count': 1,
'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
}, **attrs})
@staticmethod

View File

@ -8,7 +8,7 @@ from generator import generator, generate
from openvino.tools.mo.front.common.partial_infer.eltwise import eltwise_infer, eltwise_reverse_infer
from openvino.tools.mo.front.common.partial_infer.utils import shape_array, strict_compare_tensors, \
dynamic_dimension_value
dynamic_dimension_value, reverse_bypass_infer
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.middle.passes.infer import partial_infer
from openvino.tools.mo.ops.parameter import Parameter
@ -205,3 +205,41 @@ class TestElementwiseReverseInfer(unittest.TestCase):
out_shape=[1, dyn, dyn, 1],
ref_shape=[1, 4, dyn, 1],
auto_broadcast='none')
class TestUnaryElementwiseReverseInfer(unittest.TestCase):
@staticmethod
def build_and_test_reverse_inference(out_shape):
nodes = {
**shaped_parameter('undefined_shape_data', None, {'reverse_infer': Parameter.reverse_infer}),
**regular_op_with_empty_data('elementwise',
{'op': 'Sqrt', 'type': 'Sqrt',
'infer': eltwise_infer,
'reverse_infer': lambda node: reverse_bypass_infer(node,in_ports=[0])}),
**result('res'),
}
edges = [
*connect('undefined_shape_data', '0:elementwise'),
*connect('elementwise', 'res'),
]
graph = build_graph(nodes, edges)
graph.stage = 'middle'
Node(graph, 'elementwise').out_port(0).data.set_shape(shape_array(out_shape))
partial_infer(graph)
actual_shape = Node(graph, 'elementwise').in_port(0).data.get_shape()
# check that out_shape is transferred into only existing in_port(0)
assert strict_compare_tensors(actual_shape, shape_array(out_shape))
def test_reverse_infer_1(self):
self.build_and_test_reverse_inference(out_shape=[dyn, dyn, dyn, dyn])
def test_reverse_infer_2(self):
self.build_and_test_reverse_inference(out_shape=[dyn, dyn])
def test_reverse_infer_3(self):
self.build_and_test_reverse_inference(out_shape=[1, 100])