LayerNorm(PyTorch/HuggingFace pattern)->MVN+Mul+Add (#1003)

* LayerNorm(PyTorch/HuggingFace pattern)->MVN+Mul+Add. Improves perf on BERT by 5%

* Deducing the across_channels value from the axes passed to the MVN op.
Axes are normalized; if no axes are specified, falling back to the (previous) default across_channels value.

Co-authored-by: myshevts <maim.y.shevtsov@intel.com>
This commit is contained in:
Maxim Shevtsov 2020-06-25 09:25:56 +03:00 committed by GitHub
parent f81257c969
commit 7e40136c3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 188 additions and 15 deletions

View File

@ -144,6 +144,7 @@ extensions/front/kaldi/set_ports.py
extensions/front/kaldi/sigmoid_ext.py
extensions/front/kaldi/split_memoryoffsets.py
extensions/front/kaldi/tanh_component_ext.py
extensions/front/LayerNorm.py
extensions/front/Log1p.py
extensions/front/LogSoftmax.py
extensions/front/MatMul_normalizer.py

View File

@ -0,0 +1,75 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log

import numpy as np

from extensions.ops.mvn import MVN
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph, rename_nodes
class LayerNorm(FrontReplacementSubgraph):
    """
    Fuses the decomposed LayerNorm core (PyTorch/HuggingFace export pattern)

        x -> ReduceMean -> Sub -> Pow(2) -> ReduceMean -> Add(eps) -> Pow(0.5) -> Div

    into a single MVN node. Only the mean/variance normalization part is
    composed here; the trailing scale (Mul) and shift (Add) stay as-is.
    """
    enabled = True

    def pattern(self):
        log.info('Enabled LayerNorm pattern recognition')
        return dict(
            nodes=[
                ('pool0', dict(op='ReduceMean')),
                ('pool1', dict(op='ReduceMean')),
                ('pow', dict(op='Pow')),
                ('div', dict(op='Div')),
                ('sqrt', dict(op='Pow')),
                ('add', dict(op='Add')),
                ('sub', dict(op='Sub')),
                ('pool0_param', dict(op='Const')),
                ('pool1_param', dict(op='Const')),
                ('add_param', dict(op='Const')),
                ('pow_param', dict(op='Const')),
            ],
            edges=[
                ('pool0', 'sub'),
                ('sub', 'pow'),
                ('pow', 'pool1'),
                ('pool1', 'add'),
                ('add', 'sqrt'),
                ('sqrt', 'div'),
                ('sub', 'div'),
                ('pool0_param', 'pool0'),
                ('pool1_param', 'pool1'),
                ('pow_param', 'sqrt'),
                ('add_param', 'add'),
            ])

    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched subgraph with an MVN node when the constants
        confirm the LayerNorm semantics (eps <= 1e-5, sqrt via Pow(0.5),
        identical reduction axes on both ReduceMean nodes)."""
        inp = match['pool0']
        inp_port = inp.in_port(0).get_source()
        # take/check the values of the add, pow and axes for ReduceMean
        pow_param = match['pow_param']
        add_param = match['add_param']
        # BUGFIX: Const values are numpy arrays; `a == b` in a boolean context
        # raises "truth value of an array is ambiguous" for multi-element axes
        # (e.g. normalization over two trailing axes). np.array_equal gives an
        # unambiguous element-wise comparison and also handles shape mismatch.
        if add_param.value.size == 1 and pow_param.value.size == 1 and add_param.value.item() <= 1e-05 \
                and pow_param.value.item() == 0.5 \
                and np.array_equal(match['pool0_param'].value, match['pool1_param'].value):
            log.debug('Found LayerNorm pattern after {} with name {}'.format(inp_port.node.op, inp_port.node.name))
            mvn = MVN(graph, {'eps': add_param.value.item(),
                              'axes': match['pool1_param'].value,
                              'normalize_variance': 1}).create_node()
            # Keep the fused node under the Div's original name so downstream
            # name-based references stay valid.
            div_name = match['div'].soft_get('name', match['div'].id)
            rename_nodes([(match['div'], div_name + '/to_be_removed'), (mvn, div_name)])
            inp_port.connect(mvn.in_port(0))
            match['div'].out_port(0).get_connection().set_source(mvn.out_port(0))

View File

@ -0,0 +1,83 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from math import sqrt
from extensions.front.LayerNorm import LayerNorm
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
def _op_node(op_type):
    # Minimal node-attribute dict understood by build_graph.
    return {'kind': 'op', 'op': op_type}


# Nodes of the decomposed LayerNorm pattern (the graph before replacement).
nodes_attributes_mvn = {
    'inp': _op_node('AnyOp'),
    'pool0': _op_node('ReduceMean'),
    'pool1': _op_node('ReduceMean'),
    'pow': _op_node('Pow'),
    'div': _op_node('Div'),
    'sqrt': _op_node('Pow'),
    'add': _op_node('Add'),
    'sub': _op_node('Sub'),
    'add_param': _op_node('Const'),
    'pow_param': _op_node('Const'),
    'pool0_param': _op_node('Const'),
    'pool1_param': _op_node('Const'),
    'out': _op_node('AnyOp'),
}

# Nodes of the reference graph: a single MVN between input and output.
nodes_attributes_ref = {
    'inp': _op_node('AnyOp'),
    'mvn': _op_node('MVN'),
    'out': _op_node('AnyOp'),
}
class TestMVNPatternReplacement(unittest.TestCase):
    """Checks that the decomposed LayerNorm subgraph is collapsed into one MVN node."""

    def test_MVNPatternReplacement_test_1(self):
        # Edges of the pattern graph: mean subtraction, variance via Pow(2)
        # + ReduceMean, then Add(eps) -> Pow(0.5) -> Div.
        pattern_edges = [
            ('inp', 'pool0', {'out': 0}),
            ('inp', 'sub', {'out': 0}),
            ('pool0', 'sub'),
            ('sub', 'pow'),
            ('pow', 'pool1'),
            ('pool1', 'add'),
            ('add', 'sqrt'),
            ('sqrt', 'div'),
            ('sub', 'div'),
            ('div', 'out'),
            ('pow_param', 'sqrt'),
            ('add_param', 'add'),
            ('pool0_param', 'pool0'),
            ('pool1_param', 'pool1'),
        ]
        # Constant values that make the subgraph a genuine LayerNorm:
        # sqrt via Pow(0.5), eps of 1e-6, both reductions over the last axis.
        const_values = {
            'pow_param': {'shape': np.array([1]), 'value': np.array(0.5)},
            'add_param': {'shape': np.array([1]), 'value': np.array(1e-06)},
            'pool0_param': {'shape': np.array([1]), 'value': np.array(-1)},
            'pool1_param': {'shape': np.array([1]), 'value': np.array(-1)},
        }

        graph = build_graph(nodes_attributes_mvn, pattern_edges, const_values,
                            nodes_with_edges_only=True)
        graph.stage = 'front'

        graph_ref = build_graph(nodes_attributes_ref,
                                [('inp', 'mvn'), ('mvn', 'out')],
                                {}, nodes_with_edges_only=True)

        LayerNorm().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'out', check_op_attrs=True)
        self.assertTrue(flag, resp)

View File

@ -13,8 +13,9 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.layout import get_features_dim
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.graph.graph import Graph
from mo.ops.op import Op
from mo.utils.error import Error
@ -31,7 +32,7 @@ class MVN(Op):
'op': __class__.op,
'version': 'opset2',
'eps': None,
'across_channels': 0,
'across_channels': None,
'normalize_variance': 1,
'axes': None,
'in_ports_count': 1,
@ -49,18 +50,31 @@ class MVN(Op):
def infer(node: None):
    """
    Shape inference for MVN: validates the axes/across_channels attributes
    and copies the input shape to the output (MVN is shape-preserving).

    NOTE(review): this span of the rendered diff interleaved the removed
    pre-commit body with the added one; only the post-commit logic
    (hunk "+50,31") is kept here.

    :param node: MVN node to infer; reads node.axes / node.across_channels,
                 may set node.across_channels as deduced from axes.
    :raises Error: if both axes and across_channels are set, if the batch
                   axis is reduced, or if a spatial axis is not reduced.
    """
    input_shape = node.in_node(0).shape
    name = node.soft_get('name', node.id)

    # axes and across_channels are alternative ways to specify the reduction;
    # exactly one of them may be provided by the caller.
    if node.axes is not None and node.across_channels is not None:
        raise Error('Either axes or across_channels can be set for the MVN in node "{}".'.format(name))

    if node.across_channels is None:
        if node.axes is not None:
            # normalizing (replacing -1 with actual index)
            axes_data_value = node.axes
            axes = [axes_data_value.item()] if axes_data_value.size == 1 else axes_data_value
            axes = [get_canonical_axis_index(input_shape, a) for a in axes]
            # deduce across_channels from the axes, e.g. if the first axis
            # is included (assuming batch is zero axis)
            feature_dim = get_features_dim(node.graph.graph['layout'], len(input_shape)) \
                if (4 <= len(input_shape) <= 5) \
                else 1
            node.across_channels = int(feature_dim in axes)

            if 0 in axes:
                raise Error('Reduction over the batch dimension in node "{}" '
                            'is not supported by the backend.'.format(name))
            # the backend requires every spatial dimension to be reduced
            for i in range(2, len(input_shape)):
                if i not in axes:
                    raise Error(
                        'Reduction over spatial dimensions in node "{}" '
                        'is obligatory for the backend.'.format(name))
        else:
            node.across_channels = 0  # default

    copy_shape_infer(node)