Files
openvino/model-optimizer/extensions/back/GroupedConvWeightsNormalize.py
Alexey Suhov 6478f1742a Align copyright notice in python scripts (CVS-51320) (#4974)
* Align copyright notice in python scripts (CVS-51320)
2021-03-26 17:54:28 +03:00

39 lines
1.5 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph
from mo.ops.const import Const
class GroupedConvWeightsNormalize(BackReplacementPattern):
"""
This pass is a workaround for nGraph GroupedConvolution operation
It requires that weights layout will be next: G*O*I,1,H,W
"""
enabled = True
force_clean_up = True
def pattern(self):
return dict(
nodes=[
('conv', {'type': 'Convolution', 'group': lambda x: x != 1}),
('weights', {'type': 'Const', 'kind': 'op'}),
('weights_data', {'kind': 'data'}),
],
edges=[('weights', 'weights_data'), ('weights_data', 'conv')]
)
def replace_pattern(self, graph: Graph, match: dict):
conv = match['conv']
weights = match['weights']
input_shape = conv.in_port(0).data.get_shape()
new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group), input_shape[1] / conv.group, *weights.value.shape[2:]])
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape),
'name': weights.soft_get('name', weights.id) + '_new'}).create_node()
weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
new_weights.infer(new_weights)