Fix fusing Multiply node with Convolution in case group != 1 (#1882)
* Fix fusing Multiply node with Convolution in case group != 1 * Add transformation test * Do not fuse if not possible to reshape const * Update fuse_linear_ops.py
This commit is contained in:
parent
51fa5ab8cb
commit
18a49f9e7e
@ -105,7 +105,14 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
|
||||
shape = np.append(shape, 1)
|
||||
|
||||
mul_val = np.array(value)
|
||||
value = np.reshape(value, shape)
|
||||
# If the value fails to reshape to the provided shape, skip fusing.
|
||||
# This can happen in case of group != 1 of the convolution.
|
||||
try:
|
||||
value = np.reshape(value, shape)
|
||||
except ValueError:
|
||||
log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
|
||||
node.soft_get('name', node.id),fuse_node.soft_get('name', fuse_node.id)), extra={'is_warning': True})
|
||||
return False
|
||||
|
||||
# Weights multiplication
|
||||
mul_name = node.name + '_copy'
|
||||
|
@ -22,7 +22,7 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer
|
||||
from mo.graph.graph import Node
|
||||
from mo.middle.passes.fusing.fuse_linear_ops import _fuse_mul, fuse_linear_ops
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, const_with_data, connect
|
||||
|
||||
nodes_attributes = {
|
||||
'placeholder_1': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
@ -36,7 +36,7 @@ nodes_attributes = {
|
||||
'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
# Mul and Add operations
|
||||
'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True,
|
||||
'infer': lambda node: eltwise_infer(node, lambda a, b: a*b)},
|
||||
'infer': lambda node: eltwise_infer(node, lambda a, b: a * b)},
|
||||
'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
|
||||
'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
@ -69,6 +69,12 @@ nodes_attributes = {
|
||||
'const_conv_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
|
||||
'const_conv_2_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
|
||||
'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'deconv': {'type': 'Deconvolution', 'kind': 'op', 'op': 'Deconv2D', 'layout': 'NHWC'},
|
||||
'deconv_w': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'deconv_b': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'const_deconv_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
|
||||
'const_deconv_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None, 'op': 'Const'},
|
||||
'deconv_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
# MatMul
|
||||
'fc_1': {'type': 'MatMul', 'kind': 'op', 'layout': 'NHWC', 'op': 'MatMul'},
|
||||
'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
@ -670,7 +676,8 @@ class FuseMulTests(unittest.TestCase):
|
||||
'const_mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
|
||||
'mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
|
||||
'conv_1': {'can_be_fused': False},
|
||||
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
|
||||
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
|
||||
'value': np.ones((11, 11, 3, 96))},
|
||||
'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
|
||||
'output_channel_dim': 3, 'input_channel_dim': 2,
|
||||
'dims_number': 4},
|
||||
@ -783,6 +790,46 @@ class FuseMulTests(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1', 'placeholder_1')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
# Deconv(w)->Mul(array)
|
||||
def test_fuse_mul_to_deconv_1(self):
|
||||
# Placeholder->Deonv->Mul
|
||||
in_shape = np.array([1, 20, 10, 10])
|
||||
w_shape = np.array([20, 2, 3, 3])
|
||||
out_shape = np.array([1, 10, 21, 21])
|
||||
mul_const = np.array(range(10))
|
||||
|
||||
edges = [('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'deconv'),
|
||||
('const_deconv_w', 'deconv_w'),
|
||||
('deconv_w', 'deconv'),
|
||||
('deconv', 'deconv_data'),
|
||||
('deconv_data', 'mul_1'),
|
||||
('const_mul_1_w', 'mul_1_w'),
|
||||
('mul_1_w', 'mul_1'),
|
||||
('mul_1', 'mul_1_data'),
|
||||
('mul_1_data', 'op_output')
|
||||
]
|
||||
attr_updates = {'placeholder_1_data': {'shape': in_shape},
|
||||
'const_conv_1_w': {'shape': w_shape, 'value': np.ones(w_shape)},
|
||||
'deconv': {'group': 5},
|
||||
'deconv_w': {'shape': w_shape, 'value': np.ones(w_shape),
|
||||
'output_channel_dim': 1, 'input_channel_dim': 0,
|
||||
'dims_number': 4},
|
||||
'deconv_data': {'shape': out_shape},
|
||||
'mul_1_data': {'shape': mul_const.shape},
|
||||
'const_mul_1_w': {'shape': mul_const.shape, 'value': mul_const},
|
||||
'mul_1_w': {'shape': mul_const.shape, 'value': mul_const},
|
||||
}
|
||||
graph = build_graph(nodes_attributes, edges, attr_updates)
|
||||
# same graph, nothing fused
|
||||
graph_ref = build_graph(nodes_attributes, edges, attr_updates)
|
||||
|
||||
_fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'deconv')], backward=True)
|
||||
graph.clean_up()
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1', 'placeholder_1')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
# Unit tests for fuse_linear_ops
|
||||
class FuseLinOpsTests(unittest.TestCase):
|
||||
@ -1011,7 +1058,6 @@ class FuseLinOpsTests(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
|
||||
# TODO: refactor this test
|
||||
# self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
# Op->Mul(array)-+->Conv(w+b)------+->Concat
|
||||
# | | => Same('can_be_fused': False)
|
||||
@ -1090,7 +1136,8 @@ class FuseLinOpsTests(unittest.TestCase):
|
||||
'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
|
||||
'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
|
||||
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
|
||||
'value': np.ones((11, 11, 3, 96))},
|
||||
'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
|
||||
'output_channel_dim': 3, 'input_channel_dim': 2,
|
||||
'dims_number': 4},
|
||||
@ -1098,7 +1145,8 @@ class FuseLinOpsTests(unittest.TestCase):
|
||||
'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
|
||||
'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
|
||||
'conv_2': {'can_be_fused': False},
|
||||
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
|
||||
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]),
|
||||
'value': np.ones((11, 11, 3, 96))},
|
||||
'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
|
||||
'output_channel_dim': 3, 'input_channel_dim': 2,
|
||||
'dims_number': 4},
|
||||
@ -1191,14 +1239,16 @@ class FuseLinOpsTests(unittest.TestCase):
|
||||
'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
|
||||
'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
|
||||
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
|
||||
'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
|
||||
'value': np.ones((11, 11, 3, 96))},
|
||||
'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
|
||||
'output_channel_dim': 3, 'input_channel_dim': 2,
|
||||
'dims_number': 4},
|
||||
'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
|
||||
'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
|
||||
'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
|
||||
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
|
||||
'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]),
|
||||
'value': np.ones((11, 11, 3, 96))},
|
||||
'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
|
||||
'output_channel_dim': 3, 'input_channel_dim': 2,
|
||||
'dims_number': 4},
|
||||
|
Loading…
Reference in New Issue
Block a user