Fixed conversion of some models with (I)DFT when a layer immediately before (I)DFT is a producer for Result (#9489)
* Fix in the transformation PreserveRuntimeInfo: now Transpose is inserted before input port 0 of Result only, not after data node of layer before Result layer. * Deleted commented code. * Added more tests for the MO transformation PreserveRuntimeInfo.
This commit is contained in:
parent
acdbbf4363
commit
ebcd9eaf07
@ -118,7 +118,7 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
|
||||
transpose.name = in_node.name
|
||||
in_node.name += "/prev"
|
||||
|
||||
prev_node_out_port.get_connection().insert_node(transpose)
|
||||
op.in_port(0).get_connection().insert_node(transpose)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -6,11 +6,12 @@ import unittest
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from openvino.tools.mo.middle.PreserveRuntimeInfo import PreserveRuntimeInfo
|
||||
from openvino.tools.mo.ops.transpose import Transpose
|
||||
from openvino.tools.mo.front.common.partial_infer.elemental import copy_shape_infer
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.middle.PreserveRuntimeInfo import PreserveRuntimeInfo
|
||||
from openvino.tools.mo.ops.op import PermuteAttrs
|
||||
from openvino.tools.mo.ops.transpose import Transpose
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from openvino.tools.mo.utils.runtime_info import RTInfo
|
||||
from unit_tests.utils.graph import build_graph, connect, valued_const_with_data, regular_op_with_empty_data, \
|
||||
@ -34,6 +35,65 @@ edges_with_transpose = [*connect('placeholder1', '0:transpose_parameter'),
|
||||
*connect('transpose_result', 'result')]
|
||||
|
||||
|
||||
nodes_for_case_with_two_results = {
|
||||
'placeholder1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
|
||||
'placeholder2': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
|
||||
'add': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'infer': copy_shape_infer},
|
||||
'add_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
|
||||
'result1': {'kind': 'op', 'op': 'Result'},
|
||||
'result2': {'kind': 'op', 'op': 'Result'},
|
||||
'fft': {'kind': 'op', 'op': 'IDFT', 'type': 'IDFT', 'infer': copy_shape_infer},
|
||||
'fft_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
|
||||
'fft_axes': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([1]), 'value': int64_array([-1])
|
||||
},
|
||||
'fft_axes_data': {'value': int64_array([-1]), 'shape': int64_array([1]), 'kind': 'data', 'data_type': np.int64},
|
||||
'transpose_parameter_order': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': None, 'value': None
|
||||
},
|
||||
'transpose_parameter_order_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.int64},
|
||||
'transpose_parameter': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'infer': Transpose.infer},
|
||||
'transpose_parameter_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
}
|
||||
|
||||
edges_for_case_with_two_results = [
|
||||
('transpose_parameter_order', 'transpose_parameter_order_data'),
|
||||
('transpose_parameter_order_data', 'transpose_parameter', {'in': 1}),
|
||||
('transpose_parameter', 'transpose_parameter_data'),
|
||||
('placeholder1', 'placeholder1_data'),
|
||||
('placeholder2', 'placeholder2_data'),
|
||||
('placeholder1_data', 'add', {'in': 0}),
|
||||
('placeholder2_data', 'add', {'in': 1}),
|
||||
('add', 'add_data'),
|
||||
('add_data', 'result1', {'out': 0, 'in': 0}),
|
||||
('add_data', 'fft', {'out': 0, 'in': 0}),
|
||||
('fft_axes', 'fft_axes_data'),
|
||||
('fft_axes_data', 'fft', {'in': 1}),
|
||||
('fft', 'fft_data'),
|
||||
('fft_data', 'result2'),
|
||||
]
|
||||
|
||||
edges_with_transpose_for_case_with_two_results = [
|
||||
('transpose_parameter_order', 'transpose_parameter_order_data'),
|
||||
('placeholder1_data', 'transpose_parameter', {'in': 0}),
|
||||
('transpose_parameter_order_data', 'transpose_parameter', {'in': 1}),
|
||||
('transpose_parameter', 'transpose_parameter_data'),
|
||||
('placeholder1', 'placeholder1_data'),
|
||||
('placeholder2', 'placeholder2_data'),
|
||||
('transpose_parameter_data', 'add', {'in': 0}),
|
||||
('placeholder2_data', 'add', {'in': 1}),
|
||||
('add', 'add_data'),
|
||||
('add_data', 'result1', {'out': 0, 'in': 0}),
|
||||
('add_data', 'fft', {'out': 0, 'in': 0}),
|
||||
('fft_axes', 'fft_axes_data'),
|
||||
('fft_axes_data', 'fft', {'in': 1}),
|
||||
('fft', 'fft_data'),
|
||||
('fft_data', 'result2'),
|
||||
]
|
||||
|
||||
|
||||
@generator
|
||||
class PreserveRuntimeInfoTest(unittest.TestCase):
|
||||
@generate(*[
|
||||
@ -122,3 +182,65 @@ class PreserveRuntimeInfoTest(unittest.TestCase):
|
||||
rt_info = result_node.rt_info.info
|
||||
old_api_map = rt_info[('old_api_map_order', 0)].info
|
||||
self.assertTrue(np.array_equal(old_api_map['order'], [0, 3, 1, 2]))
|
||||
|
||||
@generate(*[
|
||||
([0, 3, 1, 2], [0, 2, 3, 1], True, 'DFT'),
|
||||
([0, 3, 1, 2], [0, 2, 3, 1], True, 'IDFT'),
|
||||
(None, None, False, 'DFT'),
|
||||
(None, None, False, 'IDFT'),
|
||||
([0, 4, 1, 2, 3], [0, 2, 3, 4, 1], True, 'DFT'),
|
||||
([0, 4, 1, 2, 3], [0, 2, 3, 4, 1], True, 'IDFT'),
|
||||
])
|
||||
def test_transpose_insert_with_two_result_nodes(self, nhwc_to_nchw_order, nchw_to_nhwc_order,
|
||||
add_permutation_attrs, fft_kind):
|
||||
shape_len = len(nhwc_to_nchw_order) if add_permutation_attrs else 3
|
||||
shape = np.array(range(shape_len))
|
||||
add_shape = shape if nhwc_to_nchw_order is None else shape[nhwc_to_nchw_order]
|
||||
graph = build_graph(nodes_attrs=nodes_for_case_with_two_results,
|
||||
edges=edges_for_case_with_two_results,
|
||||
update_attributes={
|
||||
'placeholder1_data': {'shape': int64_array(shape)},
|
||||
'placeholder1': {'shape': int64_array(shape), 'rt_info': RTInfo()},
|
||||
'transpose_parameter_order': {
|
||||
'value': np.array(nhwc_to_nchw_order),
|
||||
'shape': int64_array(np.array(nhwc_to_nchw_order).shape)
|
||||
},
|
||||
'transpose_parameter_order_data': {
|
||||
'value': np.array(nhwc_to_nchw_order),
|
||||
'shape': int64_array(np.array(nhwc_to_nchw_order).shape)
|
||||
},
|
||||
'fft': {'op': fft_kind, 'type': fft_kind},
|
||||
'add_data': {'shape': add_shape},
|
||||
'fft_data': {'shape': add_shape},
|
||||
'result1': {'shape': shape, 'rt_info': RTInfo()},
|
||||
'result2': {'shape': shape, 'rt_info': RTInfo()},
|
||||
})
|
||||
|
||||
if add_permutation_attrs:
|
||||
graph_ref = build_graph(nodes_for_case_with_two_results, edges_with_transpose_for_case_with_two_results)
|
||||
else:
|
||||
graph_ref = build_graph(nodes_for_case_with_two_results, edges_for_case_with_two_results)
|
||||
|
||||
param1_node = Node(graph, 'placeholder1')
|
||||
result1_node = Node(graph, 'result1')
|
||||
result2_node = Node(graph, 'result2')
|
||||
|
||||
if add_permutation_attrs:
|
||||
shape_len = len(nhwc_to_nchw_order)
|
||||
param1_node['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
|
||||
param1_node.out_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)
|
||||
result1_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)
|
||||
result2_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)
|
||||
|
||||
PreserveRuntimeInfo().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result1')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
self.assertFalse(param1_node.has_valid('permute_attrs'))
|
||||
self.assertFalse(param1_node.out_node(0).has_valid('permutation'))
|
||||
|
||||
if add_permutation_attrs:
|
||||
rt_info = param1_node.rt_info.info
|
||||
old_api_map = rt_info[('old_api_map_order', 0)].info
|
||||
self.assertTrue(np.array_equal(old_api_map['inverse_order'], nchw_to_nhwc_order))
|
||||
|
Loading…
Reference in New Issue
Block a user