diff --git a/tools/mo/openvino/tools/mo/middle/PreserveRuntimeInfo.py b/tools/mo/openvino/tools/mo/middle/PreserveRuntimeInfo.py index a3ba0657f05..7bdb1c4f02f 100644 --- a/tools/mo/openvino/tools/mo/middle/PreserveRuntimeInfo.py +++ b/tools/mo/openvino/tools/mo/middle/PreserveRuntimeInfo.py @@ -118,7 +118,7 @@ class PreserveRuntimeInfo(MiddleReplacementPattern): transpose.name = in_node.name in_node.name += "/prev" - prev_node_out_port.get_connection().insert_node(transpose) + op.in_port(0).get_connection().insert_node(transpose) else: continue diff --git a/tools/mo/unit_tests/mo/middle/PreserveRuntimeInfo_test.py b/tools/mo/unit_tests/mo/middle/PreserveRuntimeInfo_test.py index 7a39ae62bf4..ac5f822b4ed 100644 --- a/tools/mo/unit_tests/mo/middle/PreserveRuntimeInfo_test.py +++ b/tools/mo/unit_tests/mo/middle/PreserveRuntimeInfo_test.py @@ -6,11 +6,12 @@ import unittest import numpy as np from generator import generator, generate -from openvino.tools.mo.middle.PreserveRuntimeInfo import PreserveRuntimeInfo -from openvino.tools.mo.ops.transpose import Transpose from openvino.tools.mo.front.common.partial_infer.elemental import copy_shape_infer +from openvino.tools.mo.front.common.partial_infer.utils import int64_array from openvino.tools.mo.graph.graph import Node +from openvino.tools.mo.middle.PreserveRuntimeInfo import PreserveRuntimeInfo from openvino.tools.mo.ops.op import PermuteAttrs +from openvino.tools.mo.ops.transpose import Transpose from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs from openvino.tools.mo.utils.runtime_info import RTInfo from unit_tests.utils.graph import build_graph, connect, valued_const_with_data, regular_op_with_empty_data, \ @@ -34,6 +35,65 @@ edges_with_transpose = [*connect('placeholder1', '0:transpose_parameter'), *connect('transpose_result', 'result')] +nodes_for_case_with_two_results = { + 'placeholder1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'placeholder1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32}, + 'placeholder2': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'placeholder2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32}, + 'add': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'infer': copy_shape_infer}, + 'add_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32}, + 'result1': {'kind': 'op', 'op': 'Result'}, + 'result2': {'kind': 'op', 'op': 'Result'}, + 'fft': {'kind': 'op', 'op': 'IDFT', 'type': 'IDFT', 'infer': copy_shape_infer}, + 'fft_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32}, + 'fft_axes': { + 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([1]), 'value': int64_array([-1]) + }, + 'fft_axes_data': {'value': int64_array([-1]), 'shape': int64_array([1]), 'kind': 'data', 'data_type': np.int64}, + 'transpose_parameter_order': { + 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': None, 'value': None + }, + 'transpose_parameter_order_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.int64}, + 'transpose_parameter': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'infer': Transpose.infer}, + 'transpose_parameter_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, +} + +edges_for_case_with_two_results = [ + ('transpose_parameter_order', 'transpose_parameter_order_data'), + ('transpose_parameter_order_data', 'transpose_parameter', {'in': 1}), + ('transpose_parameter', 'transpose_parameter_data'), + ('placeholder1', 'placeholder1_data'), + ('placeholder2', 'placeholder2_data'), + ('placeholder1_data', 'add', {'in': 0}), + ('placeholder2_data', 'add', {'in': 1}), + ('add', 'add_data'), + ('add_data', 'result1', {'out': 0, 'in': 0}), + ('add_data', 'fft', {'out': 0, 'in': 0}), + ('fft_axes', 'fft_axes_data'), + ('fft_axes_data', 'fft', {'in': 1}), + ('fft', 'fft_data'), + ('fft_data', 'result2'), +] + +edges_with_transpose_for_case_with_two_results = [ + ('transpose_parameter_order', 'transpose_parameter_order_data'), + ('placeholder1_data', 'transpose_parameter', {'in': 0}), + ('transpose_parameter_order_data', 'transpose_parameter', {'in': 1}), + ('transpose_parameter', 'transpose_parameter_data'), + ('placeholder1', 'placeholder1_data'), + ('placeholder2', 'placeholder2_data'), + ('transpose_parameter_data', 'add', {'in': 0}), + ('placeholder2_data', 'add', {'in': 1}), + ('add', 'add_data'), + ('add_data', 'result1', {'out': 0, 'in': 0}), + ('add_data', 'fft', {'out': 0, 'in': 0}), + ('fft_axes', 'fft_axes_data'), + ('fft_axes_data', 'fft', {'in': 1}), + ('fft', 'fft_data'), + ('fft_data', 'result2'), +] + + @generator class PreserveRuntimeInfoTest(unittest.TestCase): @generate(*[ @@ -122,3 +182,65 @@ class PreserveRuntimeInfoTest(unittest.TestCase): rt_info = result_node.rt_info.info old_api_map = rt_info[('old_api_map_order', 0)].info self.assertTrue(np.array_equal(old_api_map['order'], [0, 3, 1, 2])) + + @generate(*[ + ([0, 3, 1, 2], [0, 2, 3, 1], True, 'DFT'), + ([0, 3, 1, 2], [0, 2, 3, 1], True, 'IDFT'), + (None, None, False, 'DFT'), + (None, None, False, 'IDFT'), + ([0, 4, 1, 2, 3], [0, 2, 3, 4, 1], True, 'DFT'), + ([0, 4, 1, 2, 3], [0, 2, 3, 4, 1], True, 'IDFT'), + ]) + def test_transpose_insert_with_two_result_nodes(self, nhwc_to_nchw_order, nchw_to_nhwc_order, + add_permutation_attrs, fft_kind): + shape_len = len(nhwc_to_nchw_order) if add_permutation_attrs else 3 + shape = np.array(range(shape_len)) + add_shape = shape if nhwc_to_nchw_order is None else shape[nhwc_to_nchw_order] + graph = build_graph(nodes_attrs=nodes_for_case_with_two_results, + edges=edges_for_case_with_two_results, + update_attributes={ + 'placeholder1_data': {'shape': int64_array(shape)}, + 'placeholder1': {'shape': int64_array(shape), 'rt_info': RTInfo()}, + 'transpose_parameter_order': { + 'value': np.array(nhwc_to_nchw_order), + 'shape': int64_array(np.array(nhwc_to_nchw_order).shape) + }, + 'transpose_parameter_order_data': { + 'value': np.array(nhwc_to_nchw_order), + 'shape': int64_array(np.array(nhwc_to_nchw_order).shape) + }, + 'fft': {'op': fft_kind, 'type': fft_kind}, + 'add_data': {'shape': add_shape}, + 'fft_data': {'shape': add_shape}, + 'result1': {'shape': shape, 'rt_info': RTInfo()}, + 'result2': {'shape': shape, 'rt_info': RTInfo()}, + }) + + if add_permutation_attrs: + graph_ref = build_graph(nodes_for_case_with_two_results, edges_with_transpose_for_case_with_two_results) + else: + graph_ref = build_graph(nodes_for_case_with_two_results, edges_for_case_with_two_results) + + param1_node = Node(graph, 'placeholder1') + result1_node = Node(graph, 'result1') + result2_node = Node(graph, 'result2') + + if add_permutation_attrs: + shape_len = len(nhwc_to_nchw_order) + param1_node['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')]) + param1_node.out_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len) + result1_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len) + result2_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len) + + PreserveRuntimeInfo().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result1') + self.assertTrue(flag, resp) + + self.assertFalse(param1_node.has_valid('permute_attrs')) + self.assertFalse(param1_node.out_node(0).has_valid('permutation')) + + if add_permutation_attrs: + rt_info = param1_node.rt_info.info + old_api_map = rt_info[('old_api_map_order', 0)].info + self.assertTrue(np.array_equal(old_api_map['inverse_order'], nchw_to_nhwc_order))