[MO] fix UnaryElementwise reverse_infer (#15366)
* fix UnaryElementwise reverse_infer * fixed tests for UnaryElementwise reverse_infer * reverted autocorrection edits
This commit is contained in:
parent
d092f5d7dd
commit
1ae0b2796e
@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer, eltwise_reverse_infer
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import float32_array
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import float32_array, reverse_bypass_infer
|
||||
from openvino.tools.mo.graph.graph import Graph, Node
|
||||
from openvino.tools.mo.middle.passes.infer import copy_type_infer
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
@ -48,6 +48,7 @@ class UnaryElementwise(Elementwise):
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {**{
|
||||
'in_ports_count': 1,
|
||||
'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
|
||||
}, **attrs})
|
||||
|
||||
@staticmethod
|
||||
|
@ -8,7 +8,7 @@ from generator import generator, generate
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.eltwise import eltwise_infer, eltwise_reverse_infer
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import shape_array, strict_compare_tensors, \
|
||||
dynamic_dimension_value
|
||||
dynamic_dimension_value, reverse_bypass_infer
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.middle.passes.infer import partial_infer
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
@ -205,3 +205,41 @@ class TestElementwiseReverseInfer(unittest.TestCase):
|
||||
out_shape=[1, dyn, dyn, 1],
|
||||
ref_shape=[1, 4, dyn, 1],
|
||||
auto_broadcast='none')
|
||||
|
||||
|
||||
class TestUnaryElementwiseReverseInfer(unittest.TestCase):
|
||||
@staticmethod
|
||||
def build_and_test_reverse_inference(out_shape):
|
||||
|
||||
nodes = {
|
||||
**shaped_parameter('undefined_shape_data', None, {'reverse_infer': Parameter.reverse_infer}),
|
||||
**regular_op_with_empty_data('elementwise',
|
||||
{'op': 'Sqrt', 'type': 'Sqrt',
|
||||
'infer': eltwise_infer,
|
||||
'reverse_infer': lambda node: reverse_bypass_infer(node,in_ports=[0])}),
|
||||
**result('res'),
|
||||
}
|
||||
|
||||
edges = [
|
||||
*connect('undefined_shape_data', '0:elementwise'),
|
||||
*connect('elementwise', 'res'),
|
||||
]
|
||||
|
||||
graph = build_graph(nodes, edges)
|
||||
graph.stage = 'middle'
|
||||
Node(graph, 'elementwise').out_port(0).data.set_shape(shape_array(out_shape))
|
||||
|
||||
partial_infer(graph)
|
||||
actual_shape = Node(graph, 'elementwise').in_port(0).data.get_shape()
|
||||
|
||||
# check that out_shape is transferred into only existing in_port(0)
|
||||
assert strict_compare_tensors(actual_shape, shape_array(out_shape))
|
||||
|
||||
def test_reverse_infer_1(self):
|
||||
self.build_and_test_reverse_inference(out_shape=[dyn, dyn, dyn, dyn])
|
||||
|
||||
def test_reverse_infer_2(self):
|
||||
self.build_and_test_reverse_inference(out_shape=[dyn, dyn])
|
||||
|
||||
def test_reverse_infer_3(self):
|
||||
self.build_and_test_reverse_inference(out_shape=[1, 100])
|
||||
|
Loading…
Reference in New Issue
Block a user