Provide GatherND with original layout for inputs and output (#3002)
* Provide GatherND with original layout for inputs and output Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix code review #1 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
4a0170dcb0
commit
e3b879ad3b
@ -28,6 +28,7 @@ extensions/back/GroupedConvWeightsNormalize.py
|
|||||||
extensions/back/insert_compatibility_l2normalization.py
|
extensions/back/insert_compatibility_l2normalization.py
|
||||||
extensions/back/InterpolateReshape.py
|
extensions/back/InterpolateReshape.py
|
||||||
extensions/back/kaldi_remove_memory_output.py
|
extensions/back/kaldi_remove_memory_output.py
|
||||||
|
extensions/back/LayoutChangeForGatherND.py
|
||||||
extensions/back/LeakyReLUMutation.py
|
extensions/back/LeakyReLUMutation.py
|
||||||
extensions/back/LRNToNorm.py
|
extensions/back/LRNToNorm.py
|
||||||
extensions/back/MatMulNormalizer.py
|
extensions/back/MatMulNormalizer.py
|
||||||
|
61
model-optimizer/extensions/back/LayoutChangeForGatherND.py
Normal file
61
model-optimizer/extensions/back/LayoutChangeForGatherND.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Copyright (C) 2018-2020 Intel Corporation
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.ops.transpose import Transpose
|
||||||
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||||
|
from mo.graph.graph import Graph, Port
|
||||||
|
from mo.back.replacement import BackReplacementPattern
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutChangeForGatherND(BackReplacementPattern):
|
||||||
|
"""
|
||||||
|
Return original layout for inputs and output of GatherND operation
|
||||||
|
since the operation is designed for NHWC layout.
|
||||||
|
"""
|
||||||
|
enabled = True
|
||||||
|
force_shape_inference = True
|
||||||
|
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def insert_transpose(graph: Graph, input_port: Port, before_input=True):
|
||||||
|
input_rank = len(input_port.data.get_shape())
|
||||||
|
if input_rank > 3:
|
||||||
|
if before_input:
|
||||||
|
axis_order = np.concatenate((int64_array([0]),
|
||||||
|
int64_array(list(range(2, input_rank))),
|
||||||
|
int64_array([1])))
|
||||||
|
source_node = input_port.get_source().node
|
||||||
|
transpose_name = source_node.soft_get('name', source_node.id) + '/TransposeToNHWC'
|
||||||
|
else:
|
||||||
|
axis_order = np.concatenate(
|
||||||
|
(int64_array([0]),
|
||||||
|
int64_array([input_rank - 1]),
|
||||||
|
int64_array(list(range(1, input_rank - 1)))))
|
||||||
|
transpose_name = input_port.node.soft_get('name', input_port.node.id) + '/TransposeToNCHW'
|
||||||
|
input_port.node['need_shape_inference'] = True
|
||||||
|
input_port.node['override_output_shape'] = True
|
||||||
|
transpose = create_op_with_const_inputs(graph, Transpose, {1: axis_order}, {'name': transpose_name})
|
||||||
|
input_port.get_connection().insert_node(transpose)
|
||||||
|
transpose['need_shape_inference'] = True
|
||||||
|
transpose['override_output_shape'] = True
|
||||||
|
|
||||||
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
for gathernd in graph.get_op_nodes(type='GatherND'):
|
||||||
|
self.insert_transpose(graph, gathernd.in_port(0), before_input=True)
|
||||||
|
self.insert_transpose(graph, gathernd.in_port(1), before_input=True)
|
||||||
|
self.insert_transpose(graph, gathernd.out_port(0), before_input=False)
|
139
model-optimizer/extensions/back/LayoutChangeForGatherND_test.py
Normal file
139
model-optimizer/extensions/back/LayoutChangeForGatherND_test.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Copyright (C) 2020 Intel Corporation
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.back.LayoutChangeForGatherND import LayoutChangeForGatherND
|
||||||
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from mo.utils.unittest.graph import build_graph
|
||||||
|
|
||||||
|
nodes_attributes = {
|
||||||
|
'placeholder_1': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||||
|
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||||
|
'placeholder_2': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||||
|
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||||
|
# GatherND
|
||||||
|
'gathernd': {'type': 'GatherND', 'kind': 'op', 'op': 'GatherND'},
|
||||||
|
'gathernd_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
# Result layer
|
||||||
|
'result': {'type': 'Result', 'kind': 'op', 'op': 'Result'},
|
||||||
|
# Transpose layers
|
||||||
|
'transpose_1': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'need_shape_inference': True},
|
||||||
|
'transpose_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
'axis_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||||
|
'axis_1_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||||
|
'transpose_2': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'need_shape_inference': True},
|
||||||
|
'transpose_2_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
'axis_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||||
|
'axis_2_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||||
|
'transpose_3': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'need_shape_inference': True},
|
||||||
|
'transpose_3_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||||
|
'axis_3_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||||
|
'axis_3_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutChangeForGatherNDTests(unittest.TestCase):
|
||||||
|
def test_tf_all_ports(self):
|
||||||
|
graph = build_graph(nodes_attributes,
|
||||||
|
[('placeholder_1', 'placeholder_1_data'),
|
||||||
|
('placeholder_2', 'placeholder_2_data'),
|
||||||
|
('placeholder_1_data', 'gathernd'),
|
||||||
|
('placeholder_2_data', 'gathernd'),
|
||||||
|
('gathernd', 'gathernd_data'),
|
||||||
|
('gathernd_data', 'result'),
|
||||||
|
],
|
||||||
|
{'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'placeholder_2_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'gathernd_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
})
|
||||||
|
graph.graph['fw'] = 'tf'
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes_attributes,
|
||||||
|
[('placeholder_1', 'placeholder_1_data'),
|
||||||
|
('placeholder_2', 'placeholder_2_data'),
|
||||||
|
('placeholder_1_data', 'transpose_1'),
|
||||||
|
('axis_1_const', 'axis_1_const_data'),
|
||||||
|
('axis_1_const_data', 'transpose_1'),
|
||||||
|
('transpose_1', 'transpose_1_data'),
|
||||||
|
('placeholder_2_data', 'transpose_2'),
|
||||||
|
('axis_2_const', 'axis_2_const_data'),
|
||||||
|
('axis_2_const_data', 'transpose_2'),
|
||||||
|
('transpose_2', 'transpose_2_data'),
|
||||||
|
('transpose_1_data', 'gathernd'),
|
||||||
|
('transpose_2_data', 'gathernd'),
|
||||||
|
('gathernd', 'gathernd_data'),
|
||||||
|
('gathernd_data', 'transpose_3'),
|
||||||
|
('axis_3_const', 'axis_3_const_data'),
|
||||||
|
('axis_3_const_data', 'transpose_3'),
|
||||||
|
('transpose_3', 'transpose_3_data'),
|
||||||
|
('transpose_3_data', 'result'),
|
||||||
|
],
|
||||||
|
{'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'placeholder_2_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'axis_1_const_data': {'value': int64_array([0, 2, 3, 1])},
|
||||||
|
'axis_2_const_data': {'value': int64_array([0, 2, 3, 1])},
|
||||||
|
'gathernd_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'axis_3_const_data': {'value': int64_array([0, 3, 1, 2])},
|
||||||
|
})
|
||||||
|
|
||||||
|
pattern = LayoutChangeForGatherND()
|
||||||
|
pattern.find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_tf_one_ports(self):
|
||||||
|
graph = build_graph(nodes_attributes,
|
||||||
|
[('placeholder_1', 'placeholder_1_data'),
|
||||||
|
('placeholder_2', 'placeholder_2_data'),
|
||||||
|
('placeholder_1_data', 'gathernd'),
|
||||||
|
('placeholder_2_data', 'gathernd'),
|
||||||
|
('gathernd', 'gathernd_data'),
|
||||||
|
('gathernd_data', 'result'),
|
||||||
|
],
|
||||||
|
{'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'placeholder_2_data': {'shape': np.array([1, 3])},
|
||||||
|
'gathernd_data': {'shape': np.array([1, 3])},
|
||||||
|
})
|
||||||
|
graph.graph['fw'] = 'tf'
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes_attributes,
|
||||||
|
[('placeholder_1', 'placeholder_1_data'),
|
||||||
|
('placeholder_2', 'placeholder_2_data'),
|
||||||
|
('placeholder_1_data', 'transpose_1'),
|
||||||
|
('axis_1_const', 'axis_1_const_data'),
|
||||||
|
('axis_1_const_data', 'transpose_1'),
|
||||||
|
('transpose_1', 'transpose_1_data'),
|
||||||
|
('transpose_1_data', 'gathernd'),
|
||||||
|
('placeholder_2_data', 'gathernd'),
|
||||||
|
('gathernd', 'gathernd_data'),
|
||||||
|
('gathernd_data', 'result'),
|
||||||
|
],
|
||||||
|
{'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
|
||||||
|
'placeholder_2_data': {'shape': np.array([1, 3])},
|
||||||
|
'axis_1_const_data': {'value': int64_array([0, 2, 3, 1])},
|
||||||
|
'gathernd_data': {'shape': np.array([1, 3])}
|
||||||
|
})
|
||||||
|
|
||||||
|
pattern = LayoutChangeForGatherND()
|
||||||
|
pattern.find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
Loading…
Reference in New Issue
Block a user