Fix transpose with reverse (#6234)

* Fix transpose with reverse

* Add unit test
This commit is contained in:
iliya mironov
2021-06-24 18:12:32 +03:00
committed by GitHub
parent 6736188526
commit 7211cd3aa6
2 changed files with 46 additions and 0 deletions

View File

@@ -25,3 +25,4 @@ class ReverseTransposeNormalization(MiddleReplacementPattern):
const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
node.add_input_port(1, skip_if_exist=True)
const.out_port(0).connect(node.in_port(1))
node['reverse_order'] = False

View File

@@ -0,0 +1,45 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.middle.ReverseTransposeNormalization import ReverseTransposeNormalization
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, connect
class ReverseTransposeNormalizationTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.nodes_attributes = {
**regular_op_with_shaped_data('placeholder', [1, 10, 20, 3], {'type': 'Parameter'}),
**regular_op_with_shaped_data('transpose', [3, 20, 10, 1],
{'type': 'Transpose', 'op': 'Transpose', 'reverse_order': True}),
**result('result'),
}
cls.ref_nodes_attributes = {
**regular_op_with_shaped_data('placeholder', [1, 10, 20, 3], {'type': 'Parameter'}),
**regular_op_with_shaped_data('transpose', [3, 20, 10, 1],
{'type': 'Transpose', 'op': 'Transpose'}),
**valued_const_with_data('transpose_order', np.array([3, 2, 1, 0])),
**result('result'),
}
def test_splice(self):
graph = build_graph(self.nodes_attributes,
[*connect('placeholder', '0:transpose'),
*connect('transpose', 'result'), ])
ReverseTransposeNormalization().find_and_replace_pattern(graph)
graph.clean_up()
ref_graph = build_graph(self.ref_nodes_attributes,
[*connect('placeholder', '0:transpose'),
*connect('transpose_order', '1:transpose'),
*connect('transpose', 'result'), ]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)