Group Convolution: prevent dyn input shape from making weights dynamic via ShapeOf (#13374)
* GroupConvolution: prevent dynamic input shape to make weights dynamic via ShapeOf * Minor fix to get rid of deprecation warnings
This commit is contained in:
parent
37dbcdac3f
commit
abd58eef5c
@ -42,55 +42,23 @@ def resolve_convolution_with_group(node: Node, group: int, ir_version: str):
|
||||
reshape = create_op_node_with_second_input(node.graph, Reshape, new_shape,
|
||||
{'override_output_shape': True})
|
||||
elif ir_version == 'V10':
|
||||
I = input_shape[1]
|
||||
if is_fully_defined(weights_shape[2:]) and is_fully_defined(I):
|
||||
new_shape = shape_array([group, node.output // group, I // group, *weights_shape[2:]])
|
||||
assert np.prod(weights_shape) == np.prod(new_shape), 'Initial weights shape {}, grouped weights shape {}' \
|
||||
''.format(weights_shape, new_shape)
|
||||
reshape = create_op_node_with_second_input(node.graph, Reshape, new_shape,
|
||||
{'override_output_shape': True, 'special_zero': True})
|
||||
else:
|
||||
# if weights/or input channel dimension is dynamic need to to compute new_shape in a new subgraph
|
||||
weights_node = node.in_port(1).get_source().node
|
||||
input_node = node.in_port(0).get_source().node
|
||||
# Concat([Constant([group, node.output // group, -1]), *weights_shape[2:]], axis=1)
|
||||
wshape = Shape(node.graph, {'name': node_name + '/WeightsShape'}).create_node()
|
||||
weights_node = node.in_port(1).get_source().node
|
||||
weights_node.out_port(0).connect(wshape.in_port(0))
|
||||
|
||||
weights_shape = Shape(node.graph, {'name': node_name + '/ShapeOfWeights'}).create_node()
|
||||
weights_shape.in_port(0).connect(weights_node.out_port(0))
|
||||
GOI = Const(node.graph, {'value': int64_array([group, node.output // group, -1]),
|
||||
'name': node_name + '/GOI_weights_part'}).create_node()
|
||||
XY = create_op_with_const_inputs(node.graph, Gather,
|
||||
port_value_dict={1: int64_array(list(range(2, len(weights_shape)))), 2: int64_array(0)},
|
||||
op_attrs={'name': node_name + '/XY_weights_part'},
|
||||
input_node=wshape)
|
||||
|
||||
weights_spatial_shape = create_op_with_const_inputs(node.graph, StridedSlice,
|
||||
port_value_dict={1: int64_array([2]),
|
||||
2: int64_array([-1])},
|
||||
op_attrs={'begin_mask': [1],
|
||||
'end_mask': [0],
|
||||
'new_axis_mask': [0],
|
||||
'shrink_axis_mask': [0],
|
||||
'ellipsis_mask': [0]},
|
||||
input_node=weights_shape)
|
||||
|
||||
const_part_of_shape = Const(node.graph, attrs=dict(name=node_name + '/GroupsAndOutputChannelsSize',
|
||||
value=int64_array(
|
||||
[group, node.output // group]))).create_node()
|
||||
|
||||
input_shape_node = Shape(node.graph, {'name': node_name + '/ShapeOfInput'}).create_node()
|
||||
input_shape_node.in_port(0).connect(input_node.out_port(0))
|
||||
|
||||
input_num_channels = create_op_with_const_inputs(node.graph, Gather,
|
||||
port_value_dict={1: int64_array([1]), 2: int64_array(0)},
|
||||
op_attrs={'name': node_name + '/GatherInputNumChannels'},
|
||||
input_node=input_shape_node)
|
||||
|
||||
# input channels num divided by number of groups to alight weights shape into [GROUPS C_OUT C_IN X Y]
|
||||
C_IN = create_op_with_const_inputs(node.graph, Div,
|
||||
port_value_dict={1: int64_array(group)},
|
||||
op_attrs={'name': node_name + '/Div'},
|
||||
input_node=input_num_channels)
|
||||
|
||||
new_shape_node = Concat(node.graph, {'axis': 0, 'in_ports_count': 3}).create_node()
|
||||
new_shape_node.in_port(0).connect(const_part_of_shape.out_port(0))
|
||||
new_shape_node.in_port(1).connect(C_IN.out_port(0))
|
||||
new_shape_node.in_port(2).connect(weights_spatial_shape.out_port(0))
|
||||
reshape = Reshape(node.graph, {'override_output_shape': True, 'special_zero': True}).create_node()
|
||||
reshape.in_port(1).connect(new_shape_node.out_port(0))
|
||||
new_shape_node = Concat(node.graph, {'axis': 0, 'in_ports_count': 2, 'name': node_name + '/weights_shape'}).create_node()
|
||||
new_shape_node.in_port(0).connect(GOI.out_port(0))
|
||||
new_shape_node.in_port(1).connect(XY.out_port(0))
|
||||
reshape = Reshape(node.graph, {'override_output_shape': True, 'special_zero': True}).create_node()
|
||||
reshape.in_port(1).connect(new_shape_node.out_port(0))
|
||||
|
||||
del node['group']
|
||||
node['type'] = 'GroupConvolution'
|
||||
@ -157,7 +125,8 @@ class ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
|
||||
return [ReshapeMutation, StridedSliceMasksNormalizer]
|
||||
from openvino.tools.mo.back.ShapeOfConstFolding import ShapeOfConstFolding
|
||||
return [ShapeOfConstFolding, ReshapeMutation, StridedSliceMasksNormalizer]
|
||||
|
||||
def run_after(self):
|
||||
return [ApplyReverseChannels]
|
||||
|
@ -31,7 +31,7 @@ def concat_infer(node):
|
||||
axis = get_canonical_axis_index(shape, node.axis)
|
||||
node.axis = axis
|
||||
|
||||
mask = np.zeros_like(shape, dtype=np.bool)
|
||||
mask = np.zeros_like(shape, dtype=bool)
|
||||
mask[axis] = True # pylint: disable=unsupported-assignment-operation
|
||||
not_mask = np.logical_not(mask) # pylint: disable=assignment-from-no-return
|
||||
for s in shapes[1:]:
|
||||
|
@ -5,6 +5,7 @@ import logging as log
|
||||
from copy import copy
|
||||
|
||||
from openvino.tools.mo.back.ConvolutionNormalizer import ConvolutionNormalizer, ConvolutionWithGroupsResolver
|
||||
from openvino.tools.mo.back.ShapeOfConstFolding import ShapeOfConstFolding
|
||||
from openvino.tools.mo.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||
from openvino.tools.mo.back.PackBinaryWeights import PackBinaryWeights
|
||||
from openvino.tools.mo.back.SpecialNodesFinalization import RemoveConstOps, CreateConstNodesReplacement
|
||||
@ -72,6 +73,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None, rename_re
|
||||
# List items order matters, do not change it.
|
||||
transformation_list = [
|
||||
ConvolutionWithGroupsResolver,
|
||||
ShapeOfConstFolding,
|
||||
StridedSliceMasksNormalizer,
|
||||
PackBinaryWeights,
|
||||
BlobNormalizer,
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
|
||||
from openvino.tools.mo.back.ConvolutionNormalizer import PullReshapeThroughFQ, V7ConvolutionWithGroupsResolver, \
|
||||
V10ConvolutionWithGroupsResolver
|
||||
from openvino.tools.mo.back.ShapeOfConstFolding import ShapeOfConstFolding
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, shape_array, dynamic_dimension_value
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
@ -185,14 +186,21 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
|
||||
|
||||
class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def apply_transformation(graph):
|
||||
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
ShapeOfConstFolding().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
|
||||
def test_v10_group_convolution_resolver(self):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
|
||||
|
||||
**valued_const_with_data('weights', np.ones([3, 8, 7, 7])),
|
||||
|
||||
**valued_const_with_data('dim', int64_array([3, 8, 1, 7, 7])),
|
||||
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||
**valued_const_with_data('new_weights', np.ones([3, 8, 1, 7, 7])),
|
||||
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 3, 'output': 24}),
|
||||
|
||||
@ -204,16 +212,14 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||
TestV10ConvolutionWithGroupsResolver.apply_transformation(graph)
|
||||
|
||||
nodes['convolution']['type'] = 'GroupConvolution'
|
||||
del nodes['convolution']['group']
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '0:reshape'),
|
||||
*connect('dim', '1:reshape'),
|
||||
*connect('reshape', '1:convolution'),
|
||||
*connect('new_weights', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
@ -226,8 +232,7 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
|
||||
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
|
||||
|
||||
**valued_const_with_data('dim', int64_array([1, 8, 1, 7, 7])),
|
||||
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||
**valued_const_with_data('new_weights', np.ones([1, 8, 1, 7, 7])),
|
||||
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
|
||||
'op': 'DepthwiseConv2dNative'}),
|
||||
@ -240,52 +245,30 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||
TestV10ConvolutionWithGroupsResolver.apply_transformation(graph)
|
||||
|
||||
nodes['convolution']['type'] = 'GroupConvolution'
|
||||
del nodes['convolution']['group']
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '0:reshape'),
|
||||
*connect('dim', '1:reshape'),
|
||||
*connect('reshape', '1:convolution'),
|
||||
*connect('new_weights', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
class TestConvolutionWithGroupsResolverForDynamicWeightsChannels(unittest.TestCase):
|
||||
def test_v10_group_convolution_resolver_for_dynamic_weights(self):
|
||||
num_groups = 2
|
||||
C_OUT = 8
|
||||
|
||||
def test_v10_group_convolution_resolver_depthwise_conv2d_dynamic(self):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', shape_array([1, dynamic_dimension_value, 224, 224]), {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('input', [-1, -1, -1, -1], {'type': 'Parameter'}),
|
||||
|
||||
**valued_const_with_data('weights', np.ones([num_groups, C_OUT, 7, 7])),
|
||||
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
|
||||
|
||||
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||
**valued_const_with_data('new_weights', np.ones([1, 8, 1, 7, 7])),
|
||||
|
||||
**regular_op_with_empty_data('ss', {'type': 'StridedSlice',
|
||||
'begin_mask': [1], 'end_mask': [0],
|
||||
'new_axis_mask': [0], 'shrink_axis_mask': [0], 'ellipsis_mask': [0]}),
|
||||
|
||||
**regular_op_with_empty_data('weights_shape', {'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('input_shape', {'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('gather', {'type': 'Gather'}),
|
||||
**regular_op_with_empty_data('concat', {'type': 'Concat'}),
|
||||
**regular_op_with_empty_data('div', {'type': 'Divide'}),
|
||||
**valued_const_with_data('channels_const', int64_array([num_groups, C_OUT / num_groups])),
|
||||
**valued_const_with_data('num_groups', int64_array(num_groups)),
|
||||
**valued_const_with_data('begin', int64_array([2])),
|
||||
**valued_const_with_data('end', int64_array([-1])),
|
||||
**valued_const_with_data('channel_index', int64_array([1])),
|
||||
**valued_const_with_data('axis', int64_array(0)),
|
||||
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': num_groups, 'output': C_OUT}),
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
|
||||
'op': 'DepthwiseConv2dNative'}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
@ -295,38 +278,18 @@ class TestConvolutionWithGroupsResolverForDynamicWeightsChannels(unittest.TestCa
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||
TestV10ConvolutionWithGroupsResolver.apply_transformation(graph)
|
||||
|
||||
nodes['convolution']['type'] = 'GroupConvolution'
|
||||
del nodes['convolution']['group']
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '0:reshape'),
|
||||
|
||||
('input_d', 'input_shape', {'in': 0, 'out': 0}),
|
||||
('weights_d', 'weights_shape', {'in': 0, 'out': 0}),
|
||||
|
||||
*connect('input_shape', '0:gather'),
|
||||
*connect('channel_index', '1:gather'),
|
||||
*connect('axis', '2:gather'),
|
||||
|
||||
*connect('weights_shape', '0:ss'),
|
||||
*connect('begin', '1:ss'),
|
||||
*connect('end', '2:ss'),
|
||||
*connect('gather', '0:div'),
|
||||
*connect('num_groups', '1:div'),
|
||||
*connect('channels_const', '0:concat'),
|
||||
*connect('div', '1:concat'),
|
||||
*connect('ss', '2:concat'),
|
||||
*connect('concat', '1:reshape'),
|
||||
|
||||
*connect('reshape', '1:convolution'),
|
||||
*connect('new_weights', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
Const.infer(Node(graph, 'convolution/GroupsAndOutputChannelsSize'))
|
||||
Const.infer(Node(graph, 'convolution/Div_input_port_1/value'))
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output')
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user