Fixed tensors propagation in GlobalPoolingToReduce. (#8835)
This commit is contained in:
parent
2d0ae6028a
commit
b3869cb127
@ -45,8 +45,8 @@ class GlobalPoolingToReduce(FrontReplacementPattern):
|
|||||||
|
|
||||||
pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
|
pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
|
||||||
src = pooling.in_port(0).get_connection().get_source()
|
src = pooling.in_port(0).get_connection().get_source()
|
||||||
pooling.in_port(0).disconnect()
|
|
||||||
src.connect(reduce.in_port(0))
|
reduce.in_port(0).get_connection().set_source(src)
|
||||||
|
|
||||||
start = Const(graph, {'value': int64_array(2)}).create_node()
|
start = Const(graph, {'value': int64_array(2)}).create_node()
|
||||||
end = Rank(graph, {'name': name + '/input_rank'}).create_node()
|
end = Rank(graph, {'name': name + '/input_rank'}).create_node()
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from extensions.front.global_pooling_to_reduce import GlobalPoolingToReduce
|
||||||
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.graph.graph import Node
|
||||||
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from unit_tests.utils.graph import build_graph, regular_op, result, build_graph_with_edge_attrs, const
|
||||||
|
|
||||||
|
nodes = {**regular_op('input', {'type': 'Parameter'}),
|
||||||
|
|
||||||
|
**regular_op('relu', {'type': 'Relu'}),
|
||||||
|
**regular_op('pooling', {'type': 'Pooling', 'global_pool': True, 'pool_method': 'avg'}),
|
||||||
|
|
||||||
|
**result('result'),
|
||||||
|
|
||||||
|
**regular_op('rank', {'type': 'Rank'}),
|
||||||
|
**regular_op('reduce_mean', {'type': 'ReduceMean'}),
|
||||||
|
**regular_op('range', {'type': 'Range'}),
|
||||||
|
**const('const_1', int64_array(2)),
|
||||||
|
**const('const_2', int64_array(1)),
|
||||||
|
|
||||||
|
}
|
||||||
|
edges = [('input', 'relu', {'in': 0, 'out': 0}), ('relu', 'pooling', {'in': 0, 'out': 0}),
|
||||||
|
('pooling', 'result', {'in': 0, 'out': 0})]
|
||||||
|
ref_edges = [('input', 'relu', {'in': 0, 'out': 0}), ('relu', 'rank', {'in': 0, 'out': 0}),
|
||||||
|
('rank', 'range', {'in': 1, 'out': 0}),
|
||||||
|
('relu', 'reduce_mean', {'in': 0, 'out': 0}),
|
||||||
|
('const_1', 'range', {'in': 0, 'out': 0}), ('const_2', 'range', {'in': 2, 'out': 0}),
|
||||||
|
('range', 'reduce_mean', {'in': 1, 'out': 0}),
|
||||||
|
('reduce_mean', 'result', {'in': 0, 'out': 0})]
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalPoolingToReduceTest(unittest.TestCase):
|
||||||
|
def test_global_pooling_to_reduce(self):
|
||||||
|
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes, ref_edges)
|
||||||
|
graph.stage = 'front'
|
||||||
|
graph.graph['layout'] = 'NCHW'
|
||||||
|
node = Node(graph, 'relu')
|
||||||
|
node.out_edge(0)['fw_tensor_debug_info'] = [('Relu_0', 'Relu_tensor')]
|
||||||
|
|
||||||
|
GlobalPoolingToReduce().find_and_replace_pattern(graph)
|
||||||
|
graph.clean_up()
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
node = Node(graph, 'relu')
|
||||||
|
edge_attrs = node.out_port(0).get_destinations()[0].get_in_edge_attrs()
|
||||||
|
self.assertTrue('fw_tensor_debug_info' in edge_attrs)
|
||||||
|
self.assertTrue(edge_attrs['fw_tensor_debug_info'] == [('Relu_0', 'Relu_tensor')])
|
Loading…
Reference in New Issue
Block a user