Fix for Reduce extractors and normalizer (#3136)
* Fixed extractor for ONNX ReduceXXX operations and fixed ReduceAxisNormalizer transformation * Unit test for ReduceAxisNormalizer transformation
This commit is contained in:
@@ -22,7 +22,9 @@ from mo.graph.graph import Node
|
||||
|
||||
|
||||
def update_reduce_node_attrs_with(node: Node, c: callable):
|
||||
axis = onnx_attr(node, 'axes', 'ints', default=None, dst_type=lambda x: int64_array(x))
|
||||
axis = onnx_attr(node, 'axes', 'ints', default=None)
|
||||
if axis is not None:
|
||||
axis = int64_array(axis)
|
||||
keep_dims = onnx_attr(node, 'keepdims', 'i', default=True)
|
||||
c.update_node_stat(node, {'axis': axis, 'keep_dims': keep_dims})
|
||||
|
||||
|
||||
@@ -17,22 +17,21 @@
|
||||
from extensions.ops.ReduceOps import reduce_map
|
||||
from extensions.ops.range import Range
|
||||
from extensions.ops.rank import Rank
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class ReduceAxisNormalizer(FrontReplacementSubgraph):
|
||||
"""
|
||||
Reduce operation requires information about axis, that is represented in original frameworks differently:
|
||||
- by layer parameter
|
||||
- by 1-port input value
|
||||
|
||||
ReduceAxisNormalizer reforms Reduce operations to store axis info in 1-port input.
|
||||
Reduce operation requires information about axis, that is represented in original frameworks differently: as an
|
||||
operation attribute or as a 1-st input port value. ReduceAxisNormalizer adds second input to Reduce operations with
|
||||
axes to normalize if axes are specified as an attribute.
|
||||
"""
|
||||
enabled = True
|
||||
force_shape_inference = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
@@ -50,23 +49,18 @@ class ReduceAxisNormalizer(FrontReplacementSubgraph):
|
||||
|
||||
# if the 'axis' is None then we still add a second input to the layer with a 1D array with 1 element equal
|
||||
# to None. The infer function handles this case because the input shape is known at this stage only
|
||||
if node.has('axis'):
|
||||
if node.has_valid('axis'):
|
||||
const = Const(graph, {'name': node_name + '/axis', 'value': node.axis}).create_node()
|
||||
node.add_input_port(1, skip_if_exist=True)
|
||||
const.out_port(0).connect(node.in_port(1))
|
||||
del graph.node[node.id]['axis']
|
||||
else:
|
||||
# The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
|
||||
|
||||
begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node()
|
||||
step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node()
|
||||
end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node()
|
||||
axes = Range(graph, dict(name=node_name + '/axes_')).create_node()
|
||||
|
||||
begin_of_range.out_port(0).connect(axes.in_port(0))
|
||||
axes = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
|
||||
dict(name=node_name + '/axes'))
|
||||
end_of_range = Rank(graph, dict(name=node_name + '/range_end')).create_node()
|
||||
node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
|
||||
end_of_range.out_port(0).connect(axes.in_port(1))
|
||||
step.out_port(0).connect(axes.in_port(2))
|
||||
|
||||
node.add_input_port(1, skip_if_exist=True)
|
||||
axes.out_port(0).connect(node.in_port(1))
|
||||
node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.reduce_axis_normalizer import ReduceAxisNormalizer
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, connect_front, regular_op
|
||||
|
||||
nodes = {
|
||||
**regular_op('parameter', {'type': 'Parameter'}),
|
||||
**regular_op('reduce', {'op': 'ReduceSum', 'axis': None}),
|
||||
**regular_op('axis', {'op': 'Const', 'type': 'Const', 'value': int64_array([1])}),
|
||||
**result(),
|
||||
}
|
||||
|
||||
edges = [
|
||||
*connect_front('parameter:0', '0:reduce'),
|
||||
*connect_front('reduce', 'output'),
|
||||
]
|
||||
|
||||
|
||||
class ReduceAxisNormalizerTest(unittest.TestCase):
|
||||
def test_reduce_axis_is_None(self):
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
|
||||
ReduceAxisNormalizer().find_and_replace_pattern(graph)
|
||||
|
||||
ref_nodes = nodes.copy()
|
||||
ref_nodes.update({**regular_op('rank', {'op': 'Rank', 'type': None}),
|
||||
**regular_op('range', {'op': 'Range', 'type': 'Range'}),
|
||||
**regular_op('begin', {'type': 'Const', 'value': int64_array([0])}),
|
||||
**regular_op('step', {'type': 'Const', 'value': int64_array([1])}),
|
||||
})
|
||||
graph_ref = build_graph(ref_nodes, [
|
||||
*edges,
|
||||
*connect_front('parameter:0', 'rank'),
|
||||
*connect_front('begin:0', '0:range'),
|
||||
*connect_front('rank:0', '1:range'),
|
||||
*connect_front('step:0', '2:range'),
|
||||
*connect_front('range:0', '1:reduce'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_reduce_axis_is_const(self):
|
||||
graph = build_graph(nodes, edges, {'reduce': {'axis': 1}}, nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*edges,
|
||||
*connect_front('axis', '1:reduce'),
|
||||
], {'axis': {'value': np.int64(1)}}, nodes_with_edges_only=True)
|
||||
|
||||
ReduceAxisNormalizer().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
Reference in New Issue
Block a user