Fix fusing Multiply node with Convolution in case group != 1 (#1882)

* Fix fusing Multiply node with Convolution in case group != 1

* Add transformation test

* Do not fuse if not possible to reshape const

* Update fuse_linear_ops.py
This commit is contained in:
Maxim Vafin 2020-09-04 20:32:51 +03:00 committed by GitHub
parent 51fa5ab8cb
commit 18a49f9e7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 9 deletions

View File

@ -105,7 +105,14 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
shape = np.append(shape, 1)
mul_val = np.array(value)
value = np.reshape(value, shape)
# If the value fails to reshape to the provided shape, skip fusing.
# This can happen in case of group != 1 of the convolution.
try:
value = np.reshape(value, shape)
except ValueError:
log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
node.soft_get('name', node.id),fuse_node.soft_get('name', fuse_node.id)), extra={'is_warning': True})
return False
# Weights multiplication
mul_name = node.name + '_copy'

View File

@ -22,7 +22,7 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.graph.graph import Node
from mo.middle.passes.fusing.fuse_linear_ops import _fuse_mul, fuse_linear_ops
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, const_with_data, connect
nodes_attributes = {
'placeholder_1': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
@ -36,7 +36,7 @@ nodes_attributes = {
'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
# Mul and Add operations
'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True,
'infer': lambda node: eltwise_infer(node, lambda a, b: a*b)},
'infer': lambda node: eltwise_infer(node, lambda a, b: a * b)},
'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
@ -69,6 +69,12 @@ nodes_attributes = {
'const_conv_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
'const_conv_2_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'deconv': {'type': 'Deconvolution', 'kind': 'op', 'op': 'Deconv2D', 'layout': 'NHWC'},
'deconv_w': {'value': None, 'shape': None, 'kind': 'data'},
'deconv_b': {'value': None, 'shape': None, 'kind': 'data'},
'const_deconv_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
'const_deconv_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
'deconv_data': {'value': None, 'shape': None, 'kind': 'data'},
# MatMul
'fc_1': {'type': 'MatMul', 'kind': 'op', 'layout': 'NHWC', 'op': 'MatMul'},
'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
@ -670,7 +676,8 @@ class FuseMulTests(unittest.TestCase):
'const_mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
'mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
'conv_1': {'can_be_fused': False},
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
'value': np.ones((11, 11, 3, 96))},
'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
'output_channel_dim': 3, 'input_channel_dim': 2,
'dims_number': 4},
@ -783,6 +790,46 @@ class FuseMulTests(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1', 'placeholder_1')
self.assertTrue(flag, resp)
# Deconv(w)->Mul(array)
def test_fuse_mul_to_deconv_1(self):
# Placeholder->Deonv->Mul
in_shape = np.array([1, 20, 10, 10])
w_shape = np.array([20, 2, 3, 3])
out_shape = np.array([1, 10, 21, 21])
mul_const = np.array(range(10))
edges = [('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'deconv'),
('const_deconv_w', 'deconv_w'),
('deconv_w', 'deconv'),
('deconv', 'deconv_data'),
('deconv_data', 'mul_1'),
('const_mul_1_w', 'mul_1_w'),
('mul_1_w', 'mul_1'),
('mul_1', 'mul_1_data'),
('mul_1_data', 'op_output')
]
attr_updates = {'placeholder_1_data': {'shape': in_shape},
'const_conv_1_w': {'shape': w_shape, 'value': np.ones(w_shape)},
'deconv': {'group': 5},
'deconv_w': {'shape': w_shape, 'value': np.ones(w_shape),
'output_channel_dim': 1, 'input_channel_dim': 0,
'dims_number': 4},
'deconv_data': {'shape': out_shape},
'mul_1_data': {'shape': mul_const.shape},
'const_mul_1_w': {'shape': mul_const.shape, 'value': mul_const},
'mul_1_w': {'shape': mul_const.shape, 'value': mul_const},
}
graph = build_graph(nodes_attributes, edges, attr_updates)
# same graph, nothing fused
graph_ref = build_graph(nodes_attributes, edges, attr_updates)
_fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'deconv')], backward=True)
graph.clean_up()
(flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1', 'placeholder_1')
self.assertTrue(flag, resp)
# Unit tests for fuse_linear_ops
class FuseLinOpsTests(unittest.TestCase):
@ -1011,7 +1058,6 @@ class FuseLinOpsTests(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
# TODO: refactor this test
# self.assertTrue(flag, resp)
# Op->Mul(array)-+->Conv(w+b)------+->Concat
# | | => Same('can_be_fused': False)
@ -1090,7 +1136,8 @@ class FuseLinOpsTests(unittest.TestCase):
'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
'value': np.ones((11, 11, 3, 96))},
'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
'output_channel_dim': 3, 'input_channel_dim': 2,
'dims_number': 4},
@ -1098,7 +1145,8 @@ class FuseLinOpsTests(unittest.TestCase):
'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
'conv_2': {'can_be_fused': False},
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]),
'value': np.ones((11, 11, 3, 96))},
'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
'output_channel_dim': 3, 'input_channel_dim': 2,
'dims_number': 4},
@ -1191,14 +1239,16 @@ class FuseLinOpsTests(unittest.TestCase):
'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
'value': np.ones((11, 11, 3, 96))},
'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
'output_channel_dim': 3, 'input_channel_dim': 2,
'dims_number': 4},
'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]),
'value': np.ones((11, 11, 3, 96))},
'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
'output_channel_dim': 3, 'input_channel_dim': 2,
'dims_number': 4},