Fix incorrect working UnpackPackReverseInputChannels for centernet (#9201)
* fix UnpackPackReverseInputChannels * Add UnpackPackReverseInputChannels test
This commit is contained in:
committed by
GitHub
parent
a83bcee4bd
commit
fa1b59b7be
@@ -66,7 +66,7 @@ class UnpackPackReverseInputChannels(FrontReplacementSubgraph):
|
||||
|
||||
reverse_channels = ReverseChannels(graph, {
|
||||
'name': pack.soft_get('name', pack.id) + '/ReverseChannels',
|
||||
'axis': int64_array(axis), 'order': int64_array([2, 0, 1])}).create_node()
|
||||
'axis': int64_array(axis), 'order': int64_array([2, 1, 0])}).create_node()
|
||||
|
||||
pack.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
|
||||
unpack.in_port(0).get_connection().set_destination(reverse_channels.in_port(0))
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.tf.UnpackPackReverseInputChannels import UnpackPackReverseInputChannels
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect_front
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
|
||||
**regular_op_with_empty_data('unpack', {'op': 'AttributedSplit', 'axis': int64_array(0)}),
|
||||
**regular_op_with_empty_data('pack', {'op': 'Pack', 'axis': int64_array(0)}),
|
||||
**result(),
|
||||
|
||||
**regular_op_with_empty_data('reverseChannels',
|
||||
{'op': 'ReverseChannels', 'order': int64_array([2, 1, 0]), 'axis': int64_array(0), 'type': None}),
|
||||
}
|
||||
|
||||
|
||||
class UnpackPackReverseInputChannelsTest(unittest.TestCase):
|
||||
def test_replace_to_reverse_channel(self):
|
||||
graph = build_graph(nodes_attrs=nodes, edges=[
|
||||
*connect_front('input:0', '0:unpack'),
|
||||
*connect_front('unpack:0', '2:pack'),
|
||||
*connect_front('unpack:1', '1:pack'),
|
||||
*connect_front('unpack:2', '0:pack'),
|
||||
*connect_front('pack:0', '0:output'),
|
||||
], nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
|
||||
UnpackPackReverseInputChannels().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes_attrs=nodes, edges=[
|
||||
*connect_front('input:0', '0:reverseChannels'),
|
||||
*connect_front('reverseChannels:0', '0:output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
Reference in New Issue
Block a user