Fix incorrect working UnpackPackReverseInputChannels for centernet (#9201)

* fix UnpackPackReverseInputChannels

* Add UnpackPackReverseInputChannels test
This commit is contained in:
Eugeny Volosenkov
2021-12-27 15:57:02 +03:00
committed by GitHub
parent a83bcee4bd
commit fa1b59b7be
2 changed files with 43 additions and 1 deletions

View File

@@ -66,7 +66,7 @@ class UnpackPackReverseInputChannels(FrontReplacementSubgraph):
reverse_channels = ReverseChannels(graph, {
'name': pack.soft_get('name', pack.id) + '/ReverseChannels',
'axis': int64_array(axis), 'order': int64_array([2, 0, 1])}).create_node()
'axis': int64_array(axis), 'order': int64_array([2, 1, 0])}).create_node()
pack.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
unpack.in_port(0).get_connection().set_destination(reverse_channels.in_port(0))

View File

@@ -0,0 +1,42 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.tf.UnpackPackReverseInputChannels import UnpackPackReverseInputChannels
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect_front
nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('unpack', {'op': 'AttributedSplit', 'axis': int64_array(0)}),
**regular_op_with_empty_data('pack', {'op': 'Pack', 'axis': int64_array(0)}),
**result(),
**regular_op_with_empty_data('reverseChannels',
{'op': 'ReverseChannels', 'order': int64_array([2, 1, 0]), 'axis': int64_array(0), 'type': None}),
}
class UnpackPackReverseInputChannelsTest(unittest.TestCase):
def test_replace_to_reverse_channel(self):
graph = build_graph(nodes_attrs=nodes, edges=[
*connect_front('input:0', '0:unpack'),
*connect_front('unpack:0', '2:pack'),
*connect_front('unpack:1', '1:pack'),
*connect_front('unpack:2', '0:pack'),
*connect_front('pack:0', '0:output'),
], nodes_with_edges_only=True)
graph.stage = 'front'
UnpackPackReverseInputChannels().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attrs=nodes, edges=[
*connect_front('input:0', '0:reverseChannels'),
*connect_front('reverseChannels:0', '0:output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)