Fix for Reduce extractors and normalizer (#3136)

* Fixed extractor for ONNX ReduceXXX operations and fixed ReduceAxisNormalizer transformation

* Unit test for ReduceAxisNormalizer transformation
This commit was authored by:
Evgeny Lazarev
2020-11-16 18:50:13 +03:00
committed by GitHub
parent 6efcdb0a21
commit 0a9d883d78
3 changed files with 89 additions and 17 deletions

View File

@@ -22,7 +22,9 @@ from mo.graph.graph import Node
def update_reduce_node_attrs_with(node: Node, c: callable):
    """Fill in common attributes of an ONNX Reduce* node and delegate to the op class.

    :param node: the graph node corresponding to an ONNX ReduceXXX operation
    :param c: the operation class whose ``update_node_stat`` sets the node attributes
    """
    # NOTE: the original had a leftover duplicate assignment
    # `axis = onnx_attr(..., dst_type=lambda x: int64_array(x))` whose result was
    # immediately overwritten — removed as a dead store (diff residue).
    axis = onnx_attr(node, 'axes', 'ints', default=None)
    # 'axes' may legitimately be absent: the default (reduce over all dimensions)
    # is resolved later during shape inference, so None is kept as-is instead of
    # being passed through int64_array (which would fail on None).
    if axis is not None:
        axis = int64_array(axis)
    # Per the ONNX Reduce* specification 'keepdims' defaults to 1 (True).
    keep_dims = onnx_attr(node, 'keepdims', 'i', default=True)
    c.update_node_stat(node, {'axis': axis, 'keep_dims': keep_dims})

View File

@@ -17,22 +17,21 @@
from extensions.ops.ReduceOps import reduce_map
from extensions.ops.range import Range
from extensions.ops.rank import Rank
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.const import Const
class ReduceAxisNormalizer(FrontReplacementSubgraph):
"""
Reduce operation requires information about axis, that is represented in original frameworks differently:
- by layer parameter
- by 1-port input value
ReduceAxisNormalizer reforms Reduce operations to store axis info in 1-port input.
Reduce operation requires information about axis, that is represented in original frameworks differently: as an
operation attribute or as a 1-st input port value. ReduceAxisNormalizer adds second input to Reduce operations with
axes to normalize if axes are specified as an attribute.
"""
enabled = True
force_shape_inference = True
def pattern(self):
return dict(
@@ -50,23 +49,18 @@ class ReduceAxisNormalizer(FrontReplacementSubgraph):
# if the 'axis' is None then we still add a second input to the layer with a 1D array with 1 element equal
# to None. The infer function handles this case because the input shape is known at this stage only
if node.has('axis'):
if node.has_valid('axis'):
const = Const(graph, {'name': node_name + '/axis', 'value': node.axis}).create_node()
node.add_input_port(1, skip_if_exist=True)
const.out_port(0).connect(node.in_port(1))
del graph.node[node.id]['axis']
else:
# The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node()
step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node()
end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node()
axes = Range(graph, dict(name=node_name + '/axes_')).create_node()
begin_of_range.out_port(0).connect(axes.in_port(0))
axes = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
dict(name=node_name + '/axes'))
end_of_range = Rank(graph, dict(name=node_name + '/range_end')).create_node()
node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
end_of_range.out_port(0).connect(axes.in_port(1))
step.out_port(0).connect(axes.in_port(2))
node.add_input_port(1, skip_if_exist=True)
axes.out_port(0).connect(node.in_port(1))
node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))

View File

@@ -0,0 +1,76 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.reduce_axis_normalizer import ReduceAxisNormalizer
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, connect_front, regular_op
# Graph pieces shared by both tests: a Parameter feeding a ReduceSum whose
# 'axis' attribute is unset (None) by default, plus a Const that is wired up
# only in the reference graph of the const-axis test.
nodes = {
**regular_op('parameter', {'type': 'Parameter'}),
**regular_op('reduce', {'op': 'ReduceSum', 'axis': None}),
**regular_op('axis', {'op': 'Const', 'type': 'Const', 'value': int64_array([1])}),
**result(),
}
# Base topology: parameter -> reduce -> output (the 'axis' Const is not connected here).
edges = [
*connect_front('parameter:0', '0:reduce'),
*connect_front('reduce', 'output'),
]
class ReduceAxisNormalizerTest(unittest.TestCase):
    """Checks that ReduceAxisNormalizer moves axes information of Reduce
    operations into a second input port."""

    def test_reduce_axis_is_None(self):
        # No 'axis' attribute set: the transformation must synthesize a
        # Rank -> Range subgraph producing [0 .. rank-1] and feed it into
        # input port 1 of the Reduce node.
        graph = build_graph(nodes, edges, nodes_with_edges_only=True)
        graph.stage = 'front'

        ReduceAxisNormalizer().find_and_replace_pattern(graph)

        # Reference graph: base fixtures extended with the Rank/Range subgraph.
        ref_nodes = {
            **nodes,
            **regular_op('rank', {'op': 'Rank', 'type': None}),
            **regular_op('range', {'op': 'Range', 'type': 'Range'}),
            **regular_op('begin', {'type': 'Const', 'value': int64_array([0])}),
            **regular_op('step', {'type': 'Const', 'value': int64_array([1])}),
        }
        ref_edges = [
            *edges,
            *connect_front('parameter:0', 'rank'),
            *connect_front('begin:0', '0:range'),
            *connect_front('rank:0', '1:range'),
            *connect_front('step:0', '2:range'),
            *connect_front('range:0', '1:reduce'),
        ]
        graph_ref = build_graph(ref_nodes, ref_edges, nodes_with_edges_only=True)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_reduce_axis_is_const(self):
        # 'axis' attribute present: it must be converted into a Const node
        # connected to input port 1 of the Reduce node.
        graph = build_graph(nodes, edges, {'reduce': {'axis': 1}}, nodes_with_edges_only=True)
        graph.stage = 'front'

        graph_ref = build_graph(nodes, [
            *edges,
            *connect_front('axis', '1:reduce'),
        ], {'axis': {'value': np.int64(1)}}, nodes_with_edges_only=True)

        ReduceAxisNormalizer().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)