Update L2NormToNorm transformation (#4154)
* Removechanges added by mistake * Update transformation * Refactor tests, add more cases * Rename variable * Refactor and rename transformation * Update tests, add more cases
This commit is contained in:
parent
14cd2d52dd
commit
c1a606d507
@ -564,7 +564,7 @@ extensions/middle/InputCut.py
|
||||
extensions/middle/InsertLayoutPropagationTransposes.py
|
||||
extensions/middle/InsertSelect.py
|
||||
extensions/middle/InterpolateSequenceToInterpolate.py
|
||||
extensions/middle/L2NormToNorm.py
|
||||
extensions/middle/L2NormFusing.py
|
||||
extensions/middle/LayoutChangeForConstantShapePaths.py
|
||||
extensions/middle/LeakyReluPattern.py
|
||||
extensions/middle/LSTMRNNSequenceToTensorIterator.py
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -18,7 +18,7 @@ import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.normalize import NormalizeOp
|
||||
from extensions.ops.normalize_l2 import NormalizeL2Op
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
@ -82,16 +82,31 @@ class L2NormToNorm(MiddleReplacementPattern):
|
||||
log.debug('The value of the "maximum_y_data" is not defined or is not constant')
|
||||
return
|
||||
|
||||
# We need to check axes which performed reduction because IE supports only 2D, 3D, 4D inputs and
|
||||
# reduction only along spatial and channel dimensions.
|
||||
input_rank = len(match['sum'].in_port(0).data.get_shape())
|
||||
if input_rank not in [2, 3, 4]:
|
||||
log.debug('IE supports L2 normalization only for 2D, 3D and 4D tensors, skip fusing transformation.')
|
||||
return
|
||||
|
||||
axes = match['sum'].in_port(1).data.get_value()
|
||||
axes = int64_array(axes)
|
||||
if axes.shape == ():
|
||||
axes = int64_array([axes])
|
||||
axes.sort()
|
||||
|
||||
if not np.array_equal(axes, int64_array(np.arange(start=1, stop=input_rank))):
|
||||
log.debug('IE doesn\'t support l2 normalization with reduction along axes {}, skip fusing transformation.'
|
||||
''.format(axes))
|
||||
return
|
||||
|
||||
# rename l2_normalize node since it will be no longer output after the transformation
|
||||
output_name = match['l2_normalize'].soft_get('name', match['l2_normalize'].id)
|
||||
normalizel2_name = output_name + '/normalizel2'
|
||||
rename_node(match['l2_normalize'], normalizel2_name)
|
||||
|
||||
normalize_node = create_op_node_with_second_input(graph, NormalizeOp,
|
||||
np.ones(shape=int64_array([match['input'].shape[-1]]),
|
||||
dtype=match['input'].data_type),
|
||||
{'name': output_name, 'eps': y,
|
||||
'across_spatial': 0, 'channel_shared': 0})
|
||||
normalize_node = create_op_node_with_second_input(graph, NormalizeL2Op, axes, {'name': output_name,
|
||||
'eps_mode': 'max', 'eps': y})
|
||||
rename_node(normalize_node, output_name)
|
||||
|
||||
match['square'].in_port(0).get_source().connect(normalize_node.in_port(0))
|
420
model-optimizer/extensions/middle/L2NormFusing_test.py
Normal file
420
model-optimizer/extensions/middle/L2NormFusing_test.py
Normal file
@ -0,0 +1,420 @@
|
||||
"""
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.L2NormFusing import L2NormToNorm
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph_with_attrs
|
||||
|
||||
# A list with nodes attributes used to build various graphs.
|
||||
nodes = [
|
||||
('l2_normalize_mul', dict(kind='op', op='Mul', name='l2_norm_name')),
|
||||
('l2_normalize_mul_data', dict(kind='data')),
|
||||
('maximum', dict(kind='op', op='Maximum')),
|
||||
('maximum_data', dict(kind='data')),
|
||||
('maximum_y_const', dict(kind='op', op='Const', value=np.array(12.e-13, dtype=np.float32))),
|
||||
('maximum_y_data', dict(kind='data', value=np.array(12.e-13, dtype=np.float32))),
|
||||
('rsqrt_pow', dict(kind='data', value=-0.5)),
|
||||
('rsqrt', dict(kind='op', op='Pow')),
|
||||
('rsqrt_data', dict(kind='data')),
|
||||
('square_pow', dict(kind='op', op='Const', value=2.)),
|
||||
('square_pow_data', dict(kind='data', value=2.)),
|
||||
('square', dict(kind='op', op='Pow')),
|
||||
('sum', dict(kind='op', op='ReduceSum')),
|
||||
('sum_data', dict(kind='data')),
|
||||
('sum_axes', dict(kind='op', op='Const')),
|
||||
# nodes added after replacement
|
||||
('normalize_node', dict(kind='op', op='NormalizeL2')),
|
||||
('weights_node', dict(kind='op', op='Const')),
|
||||
('result', dict(kind='op', op='Result'))
|
||||
]
|
||||
|
||||
edges = [
|
||||
('input', 'input_data', {'out': 0}),
|
||||
('input_data', 'square', {'in': 0}),
|
||||
('square_pow', 'square_pow_data', {'out': 0}),
|
||||
('square_pow_data', 'square', {'in': 1}),
|
||||
('square', 'square_data'),
|
||||
('square_data', 'sum'),
|
||||
('sum_axes', 'sum_axes_data'),
|
||||
('sum_axes_data', 'sum'),
|
||||
('sum', 'sum_data'),
|
||||
('maximum_y_const', 'maximum_y_data'),
|
||||
('maximum_y_data', 'maximum'),
|
||||
('sum_data', 'maximum'),
|
||||
('maximum', 'maximum_data'),
|
||||
('maximum_data', 'rsqrt', {'in': 0}),
|
||||
('rsqrt_pow', 'rsqrt', {'in': 1}),
|
||||
('rsqrt', 'rsqrt_data'),
|
||||
('rsqrt_data', 'l2_normalize_mul'),
|
||||
('input_data', 'l2_normalize_mul'),
|
||||
('l2_normalize_mul', 'l2_normalize_mul_data'),
|
||||
('l2_normalize_mul_data', 'result'),
|
||||
]
|
||||
|
||||
edges_after_replacement = [
|
||||
('input', 'input_data', {'out': 0}),
|
||||
('input_data', 'normalize_node'),
|
||||
('weights_node', 'weights_node_data'),
|
||||
('weights_node_data', 'normalize_node'),
|
||||
('normalize_node', 'l2_normalize_mul_data'),
|
||||
('l2_normalize_mul_data', 'result'),
|
||||
]
|
||||
|
||||
|
||||
class L2NormToNormTest(unittest.TestCase):
|
||||
def test_2D(self):
|
||||
input_shape = int64_array([1, 300])
|
||||
axes = int64_array([1])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=axes.shape)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('weights_node_data', dict(kind='data', value=axes.sort())),
|
||||
], edges_after_replacement, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='NormalizeL2')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_2D_scalar_axis(self):
|
||||
input_shape = int64_array([1, 300])
|
||||
axes = int64_array(1)
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('weights_node_data', dict(kind='data', value=int64_array([axes]).sort())),
|
||||
], edges_after_replacement, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='NormalizeL2')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_3D(self):
|
||||
input_shape = int64_array([1, 300, 300])
|
||||
axes = int64_array([1, 2])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('weights_node_data', dict(kind='data', value=axes.sort())),
|
||||
], edges_after_replacement, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='NormalizeL2')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([1, 2, 3])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('weights_node_data', dict(kind='data', value=axes.sort())),
|
||||
], edges_after_replacement, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='NormalizeL2')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D_mixed_axes(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([3, 1, 2])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('weights_node_data', dict(kind='data', value=axes.sort())),
|
||||
], edges_after_replacement, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='NormalizeL2')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D_multiple_consumers(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([1, 2, 3])
|
||||
weights_value = np.ones(shape=int64_array([input_shape[-1]]), dtype=np.float32)
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
('result_2', dict(kind='op', op='Result'))
|
||||
], edges + [('input_data', 'result_2')], nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('weights_node_data', dict(kind='data', value=axes.sort())),
|
||||
('result_2', dict(kind='op', op='Result'))
|
||||
], edges_after_replacement + [('input_data', 'result_2')], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='NormalizeL2')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_1D_negative(self):
|
||||
input_shape = int64_array([300])
|
||||
axes = int64_array([0])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_2D_negative(self):
|
||||
input_shape = int64_array([1, 300])
|
||||
axes = int64_array([0])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_3D_negative(self):
|
||||
input_shape = int64_array([1, 300, 300])
|
||||
axes = int64_array([2])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D_negative_1(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([0, 1, 2])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D_negative_2(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([2])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D_negative_3(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([2, 1])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_4D_negative_4(self):
|
||||
input_shape = int64_array([1, 300, 300, 3])
|
||||
axes = int64_array([2, 0])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_5D_negative(self):
|
||||
input_shape = int64_array([1, 300, 300, 300, 3])
|
||||
axes = int64_array([1, 2, 3, 4])
|
||||
|
||||
graph = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [
|
||||
('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
|
||||
('square_data', dict(kind='data', shape=input_shape)),
|
||||
('sum_axes_data', dict(kind='data', value=axes, shape=None)),
|
||||
], edges, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -1,109 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.L2NormToNorm import L2NormToNorm
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph_with_attrs
|
||||
|
||||
shape = (1, 300, 300, 3)
|
||||
weights_value = np.array([1.0, 1.0, 1.0])
|
||||
|
||||
# A list with nodes attributes used to build various graphs.
|
||||
nodes = [
|
||||
('input', dict(kind='op', shape=shape, op='Parameter', data_type=np.float32)),
|
||||
('input_data', dict(kind='data', shape=shape, data_type=np.float32)),
|
||||
('l2_normalize', dict(kind='op', op='Mul', name='l2_norm_name')),
|
||||
('l2_normalize_data', dict(kind='data')),
|
||||
('maximum', dict(kind='op', op='Maximum')),
|
||||
('maximum_data', dict(kind='data')),
|
||||
('maximum_y_const', dict(kind='op', op='Const', value=np.array(12.e-13, dtype=np.float32))),
|
||||
('maximum_y_data', dict(kind='data', value=np.array(12.e-13, dtype=np.float32))),
|
||||
('rsqrt_pow', dict(kind='data', value=-0.5)),
|
||||
('rsqrt', dict(kind='op', op='Pow')),
|
||||
('rsqrt_data', dict(kind='data')),
|
||||
('square_pow', dict(kind='op', op='Const', value=2.)),
|
||||
('square_pow_data', dict(kind='data', value=2.)),
|
||||
('square', dict(kind='op', op='Pow')),
|
||||
('square_data', dict(kind='data')),
|
||||
('sum', dict(kind='op', op='ReduceSum')),
|
||||
('sum_data', dict(kind='data')),
|
||||
# nodes added after replacement
|
||||
('normalize_node', dict(kind='op', op='Normalize')),
|
||||
('weights_node', dict(kind='op', op='Const', shape=weights_value.shape, value=weights_value)),
|
||||
('weights_node_data', dict(kind='data', op='Const')),
|
||||
('result', dict(kind='op', op='Result'))
|
||||
]
|
||||
|
||||
edges = [
|
||||
('input', 'input_data', {'out': 0}),
|
||||
('input_data', 'square', {'in': 0}),
|
||||
('square_pow', 'square_pow_data', {'out': 0}),
|
||||
('square_pow_data', 'square', {'in': 1}),
|
||||
('square', 'square_data'),
|
||||
('square_data', 'sum'),
|
||||
('sum', 'sum_data'),
|
||||
('maximum_y_const', 'maximum_y_data'),
|
||||
('maximum_y_data', 'maximum'),
|
||||
('sum_data', 'maximum'),
|
||||
('maximum', 'maximum_data'),
|
||||
('maximum_data', 'rsqrt', {'in': 0}),
|
||||
('rsqrt_pow', 'rsqrt', {'in': 1}),
|
||||
('rsqrt', 'rsqrt_data'),
|
||||
('rsqrt_data', 'l2_normalize'),
|
||||
('input_data', 'l2_normalize'),
|
||||
('l2_normalize', 'l2_normalize_data'),
|
||||
('l2_normalize_data', 'result'),
|
||||
]
|
||||
|
||||
edges_after_replacement = [
|
||||
('input', 'input_data', {'out': 0}),
|
||||
('input_data', 'normalize_node'),
|
||||
('weights_node', 'weights_node_data'),
|
||||
('weights_node_data', 'normalize_node'),
|
||||
('normalize_node', 'l2_normalize_data'),
|
||||
('l2_normalize_data', 'result'),
|
||||
]
|
||||
|
||||
|
||||
class L2NormToNormTest(unittest.TestCase):
|
||||
def test_single_consumer(self):
|
||||
graph = build_graph_with_attrs(nodes, edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes, edges_after_replacement, nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='Normalize')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_multiple_consumers(self):
|
||||
graph = build_graph_with_attrs(nodes + [('result_2', dict(kind='op', op='Result'))],
|
||||
edges + [('input_data', 'result_2')], nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
L2NormToNorm().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes + [('result_2', dict(kind='op', op='Result'))],
|
||||
edges_after_replacement+ [('input_data', 'result_2')],
|
||||
nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='Normalize')[0]]['name'] == 'l2_norm_name')
|
||||
self.assertTrue(flag, resp)
|
Loading…
Reference in New Issue
Block a user