Files
openvino/model-optimizer/extensions/back/ConvolutionNormalizer.py
Anton Chetverikov 56916ace61 Fix const node non-deterministic names (part 2) (#1081)
* Fix non-deterministic node names generation in the Model Optimizer (part 2)
2020-07-07 09:37:48 +03:00

287 lines
10 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.back.ReshapeMutation import ReshapeMutation
from extensions.back.ReverseInputChannels import ApplyReverseChannels
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.ops.reshape import Reshape
from mo.ops.strided_slice import StridedSlice
class ConvolutionNormalizer(BackReplacementPattern):
    """
    Removes the 'kernel_spatial' attribute from Convolution operation nodes.
    """
    enabled = True

    def pattern(self):
        """Describe a sub-graph made of one operation node of type Convolution."""
        conv_attrs = dict(kind='op', type='Convolution')
        return dict(nodes=[('node', conv_attrs)], edges=[])

    def replace_pattern(self, graph: Graph, match: dict):
        """Drop 'kernel_spatial' from the matched node if the node reports it as valid."""
        conv = match['node']
        if not conv.has_valid('kernel_spatial'):
            # nothing to clean up on this node
            return
        del conv['kernel_spatial']
class V7ConvolutionWithGroupsResolver(BackReplacementPattern):
    """
    Normalizes grouped convolution weights shape to fit special weights format [G*O I X Y]
    """
    enabled = False

    @staticmethod
    def pattern():
        """Match Convolution nodes whose 'group' attribute is set and differs from 1."""

        def is_grouped(g):
            return g is not None and g != 1

        return dict(
            nodes=[
                ('node', dict(type='Convolution', group=is_grouped)),
            ],
            edges=[],
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Insert a Reshape in front of input port 1 (weights) unless the first weights
        dimension already equals the node 'output' attribute.
        """
        conv = match['node']

        groups = conv.group
        assert groups > 1

        kernel_shape = conv.in_port(1).data.get_shape()
        assert kernel_shape is not None
        assert kernel_shape[0] % groups == 0

        if kernel_shape[0] != conv.output:
            # weights are not yet in [G*O I X Y] format: keep the spatial dimensions,
            # force the leading one to 'output' and let Reshape deduce the second one (-1)
            target_shape = int64_array([conv.output, -1, *kernel_shape[2:]])
            reshape_op = create_op_node_with_second_input(graph, Reshape, target_shape,
                                                          {'override_output_shape': True})
            conv.in_port(1).get_connection().insert_node(reshape_op)
class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
    """
    Normalizes grouped convolution weights shape to fit special weights format
        V10 IR:             [G O I X Y]
    """
    enabled = False

    @staticmethod
    def pattern():
        """Match Convolution nodes whose 'group' attribute is set and differs from 1."""
        return dict(
            nodes=[
                ('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
            ],
            edges=[]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Turn the matched grouped Convolution into a GroupConvolution.

        The 'group' attribute is removed, the node type is changed to 'GroupConvolution'
        and a Reshape producing the [G O I X Y] shape is inserted in front of input port 1.

        Raises AssertionError when the group is not greater than 1, the weights shape is
        unknown, its first dimension is not divisible by the group, or the element count
        of the new shape differs from the element count of the current weights shape.
        """
        node = match['node']

        group = node.group
        assert group > 1

        weights_shape = node.in_port(1).data.get_shape()
        assert weights_shape is not None
        assert weights_shape[0] % group == 0

        # dimension 1 of the data input shape; the class docstring calls it I
        in_channels = node.in_port(0).data.get_shape()[1]

        # Shape dimensions are integers, so floor division is used: true division would
        # produce floats that are only implicitly truncated by the int64 conversion.
        new_shape = int64_array([group, node.output // group, in_channels // group, *weights_shape[2:]])

        # A non-divisible 'output' or channel count changes the element count and is caught here
        assert np.prod(weights_shape) == np.prod(new_shape), \
            'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)

        del node['group']
        node['type'] = 'GroupConvolution'

        reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
                                                   {'override_output_shape': True})
        node.in_port(1).get_connection().insert_node(reshape)
class ConvolutionWithGroupsResolver(BackReplacementPattern):
    """
    Normalizes grouped convolution weights shape to fit special weights format
        V10 IR:             [G O I X Y]
        lower IR versions:  [G*O I X Y]
    """
    enabled = True
    force_clean_up = True

    def run_before(self):
        """Transformations that must be executed after this one."""
        return [ReshapeMutation]

    def run_after(self):
        """Transformations that must be executed before this one."""
        return [ApplyReverseChannels]

    def find_and_replace_pattern(self, graph: Graph):
        """Apply the three sub-transformations to the graph one after another."""
        # the order matters: V7 weights reshape, then Reshape/FakeQuantize swap, then V10 conversion
        stages = (
            V7ConvolutionWithGroupsResolver,
            PullReshapeThroughFQ,
            V10ConvolutionWithGroupsResolver,
        )
        for stage in stages:
            stage().find_and_replace_pattern(graph)
class PullReshapeThroughFQ(BackReplacementPattern):
    """
    Before:
        ... -> FQ -> Reshape -> Convolution -> ...

    After:
        ... -> Reshape -> FQ (with aligned limits) -> Convolution -> ...
    """
    enabled = False

    @staticmethod
    def pattern():
        """Match FakeQuantize -> Reshape feeding input 1 of a Convolution or GroupConvolution."""
        node_list = [
            ('FQ', dict(type='FakeQuantize')),
            ('FQed', dict()),
            ('reshape', dict(type='Reshape')),
            ('reshaped', dict()),
            ('node', dict(type=lambda t: t in ['Convolution', 'GroupConvolution'])),
        ]
        edge_list = [
            ('FQ', 'FQed'),
            ('FQed', 'reshape', {'in': 0}),
            ('reshape', 'reshaped'),
            ('reshaped', 'node', {'in': 1}),
        ]
        return dict(nodes=node_list, edges=edge_list)

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Swap the matched Reshape and FakeQuantize so that Reshape comes first.

        Nothing is changed when any of the FakeQuantize inputs 1..4 holds more than one value.
        """
        fq_node = match['FQ']
        reshape_node = match['reshape']
        conv_node = match['node']

        rank_before = reshape_node.in_port(0).data.get_shape().size
        rank_after = reshape_node.out_port(0).data.get_shape().size
        rank_changes = rank_before != rank_after

        limit_sizes = [np.prod(fq_node.in_port(port).data.get_shape()) for port in range(1, 5)]
        if any(size != 1 for size in limit_sizes):
            # FakeQuantize has limits with multiple values, that should be reshaped too
            # Pulling Reshape through such FQ is a complex procedure because of broadcasting rules
            return

        # take Reshape out of its current position
        reshape_node.in_port(0).disconnect()
        reshape_node.out_port(0).disconnect()

        # FQ now feeds the convolution directly and Reshape is placed in front of FQ data input
        fq_node.out_port(0).connect(conv_node.in_port(1))
        fq_node.in_port(0).get_connection().insert_node(reshape_node)

        for moved in (reshape_node, fq_node):
            moved['need_shape_inference'] = True
            moved['override_output_shape'] = True

        if not rank_changes:
            return

        # force rank of limit inputs to match 0-input rank
        # reshaping to lower range needs it the most due to FQ inner broadcast semantics
        for port in range(1, 5):
            limit_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([1] * rank_after),
                                                             {'override_output_shape': True})
            fq_node.in_port(port).get_connection().insert_node(limit_reshape)
class DeconvolutionNormalizer(BackReplacementPattern):
    """
    Converts Deconvolution nodes to ConvolutionBackpropData / GroupConvolutionBackpropData.

    The optional third input (output shape) is cut down to its spatial part, a stored
    'original_output_spatial_shape' is restored as a constant third input, and weights of
    grouped deconvolutions are reshaped to [G, I / G, O / G, spatial dims...].
    """
    enabled = True
    force_clean_up = True

    def run_before(self):
        """Transformations that must be executed after this one."""
        return [ReshapeMutation]

    def run_after(self):
        """Transformations that must be executed before this one."""
        return [ApplyReverseChannels]

    @staticmethod
    def pattern():
        """Match any single node of type Deconvolution."""
        return dict(
            nodes=[
                ('node', dict(type='Deconvolution'))
            ],
            edges=[]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Normalize the matched Deconvolution node in place.

        Raises AssertionError for a grouped node when the group is not greater than 1, the
        weights shape is unknown, the input channels or 'output' are not divisible by the
        group, or the element count of the grouped shape differs from the current one.
        """
        node = match['node']
        node_name = node.soft_get('name', node.id)

        if 2 in node.in_ports() and not node.in_port(2).disconnected():
            # Third input represents output shape. Cutting its value according to scheme:
            # [N, C, spatial_dim_0, ..., spatial_dim_n] -> [spatial_dim_0, ..., spatial_dim_n]
            in_rank = node.in_port(0).data.get_shape().size

            shape_src = node.in_port(2).get_source()
            node.in_port(2).disconnect()

            # StridedSlice inputs: 1 - begin (skip first two dims), 2 - end (input rank), 3 - stride
            ss_0 = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([2], dtype=np.int32),
                                                                     2: np.array([in_rank], dtype=np.int32),
                                                                     3: np.array([1], dtype=np.int32)},
                                               {'name': node_name + '/ss_0_port',
                                                'begin_mask': np.array([1], dtype=np.int32),
                                                'end_mask': np.array([0], dtype=np.int32),
                                                'new_axis_mask': np.array([0], dtype=np.int32),
                                                'shrink_axis_mask': np.array([0], dtype=np.int32),
                                                'ellipsis_mask': np.array([0], dtype=np.int32)})

            shape_src.connect(ss_0.in_port(0))
            ss_0.out_port(0).connect(node.in_port(2))

            # Specification: *padding amount* is deduced from relation of input and output spatial shapes
            del node['pad']

        elif node.has_valid('original_output_spatial_shape'):
            # node had fixed output spatial shape set in original framework, so we restore it here
            const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
                                  'name': node_name + '/original_spatial_shape'}).create_node()
            node.add_input_port(2, skip_if_exist=True)
            const.out_port(0).connect(node.in_port(2))

            # Specification: *padding amount* is deduced from relation of input and output spatial shapes
            del node['pad']

        group = node.soft_get('group', 1)

        if group != 1:
            assert group > 1

            weights_shape = node.in_port(1).data.get_shape()
            assert weights_shape is not None
            # dimension 1 of the data input shape (channels, assuming the [N, C, ...] layout)
            in_channels = node.in_port(0).data.get_shape()[1]
            assert in_channels % group == 0
            assert node.output % group == 0

            # Shape dimensions are integers, so floor division is used: true division would
            # produce floats that are only implicitly truncated by the int64 conversion.
            new_shape = int64_array([group, in_channels // group, node.output // group, *weights_shape[2:]])

            assert np.prod(weights_shape) == np.prod(new_shape), \
                'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
            # Reshape is created already connected to the current weights producer,
            # then the node weights input is switched to the Reshape output
            reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
                                                       {'override_output_shape': True},
                                                       node.in_port(1).get_source().node)

            node.in_port(1).get_connection().set_source(reshape.out_port(0))

            node['type'] = 'GroupConvolutionBackpropData'
        else:
            node['type'] = 'ConvolutionBackpropData'