[MO] Support MXNet Operations: batch_dot, LayerNorm, div_sqrt_dim (#7641)

* Add new operation support

* Update range_like replacer

* Move layer normalizer to middle

* Update bom file

* Update bom

* removed ArangeLike op

* updated bom

* added tests

* Updated docs

* comments relolving

* resolve documentation merge conflict

* arange_like op

* Revert "arange_like op"

This reverts commit a30f5bbb48.

* fixes in div_sqrt_dim

* comments resolving

* updated tests

* added batch_dot and layer_norm descriptions

* updated batch_dot comment

* updated comment

* move extractors to mxnet folder

* added replacer for batch_dot

* Revert "added replacer for batch_dot"

This reverts commit 8c0e52f7dc.

* return fully connected normalization

* fix typo

* updated fully connected normalization for mxnet

* changed assert message

* fixed gamma and beta shape incompatibility problem

* fixed imports, updated unittest

* resolve comments

Co-authored-by: iimironov <iliya.mironov@intel.com>
This commit is contained in:
Yegor Kruglov 2022-01-18 11:12:30 +03:00 committed by GitHub
parent ff2df42339
commit 9b129b7c1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 484 additions and 3 deletions

View File

@ -56,6 +56,7 @@
| _contrib_box_nms | |
| _contrib_DeformableConvolution | |
| _contrib_DeformablePSROIPooling | |
| _contrib_div_sqrt_dim | |
| _contrib_MultiBoxDetection | "force_suppress" = 1 is not supported, non-default variances are not supported |
| _contrib_MultiBoxPrior | |
| _contrib_Proposal | |
@ -77,6 +78,7 @@
| arccosh | |
| arcsinh | |
| arctanh | |
| batch_dot | |
| broadcast_add | |
| broadcast_div | |
| broadcast_mul | |
@ -94,6 +96,7 @@
| max | |
| minus_scalar | |
| null | Not needed for inference |
| LayerNorm | "output_mean_var" = True is not supported |
| repeat | |
| rnn | |
| rnn_param_concat | |

View File

@ -271,6 +271,7 @@ openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py
openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
openvino/tools/mo/front/mxnet/arange_ext.py
openvino/tools/mo/front/mxnet/arange_replacer.py
openvino/tools/mo/front/mxnet/batch_dot_ext.py
openvino/tools/mo/front/mxnet/block_grad_ext.py
openvino/tools/mo/front/mxnet/box_nms_ext.py
openvino/tools/mo/front/mxnet/cast_ext.py
@ -285,6 +286,7 @@ openvino/tools/mo/front/mxnet/custom.py
openvino/tools/mo/front/mxnet/custom_rpn_proposal.py
openvino/tools/mo/front/mxnet/deformable_conv_ext.py
openvino/tools/mo/front/mxnet/deformable_psroi_pooling_ext.py
openvino/tools/mo/front/mxnet/div_sqrt_dim.py
openvino/tools/mo/front/mxnet/dropout_ext.py
openvino/tools/mo/front/mxnet/einsum_ext.py
openvino/tools/mo/front/mxnet/elementwise_ext.py
@ -309,6 +311,7 @@ openvino/tools/mo/front/mxnet/gather.py
openvino/tools/mo/front/mxnet/gather_ext.py
openvino/tools/mo/front/mxnet/gluoncv_ssd_anchors.py
openvino/tools/mo/front/mxnet/instance_norm_ext.py
openvino/tools/mo/front/mxnet/layer_norm_ext.py
openvino/tools/mo/front/mxnet/leaky_relu.py
openvino/tools/mo/front/mxnet/loader.py
openvino/tools/mo/front/mxnet/lrn_ext.py
@ -750,6 +753,7 @@ openvino/tools/mo/middle/InsertLayoutPropagationTransposes.py
openvino/tools/mo/middle/InsertSelect.py
openvino/tools/mo/middle/InterpolateSequenceToInterpolate.py
openvino/tools/mo/middle/L2NormFusing.py
openvino/tools/mo/middle/layer_normalization.py
openvino/tools/mo/middle/LayoutChangeForConstantShapePaths.py
openvino/tools/mo/middle/LayoutChangeForEinsum.py
openvino/tools/mo/middle/LeakyReluPattern.py
@ -877,6 +881,7 @@ openvino/tools/mo/ops/dequantize_linear.py
openvino/tools/mo/ops/detection_output_onnx.py
openvino/tools/mo/ops/DetectionOutput.py
openvino/tools/mo/ops/dft.py
openvino/tools/mo/ops/div_sqrt_dim.py
openvino/tools/mo/ops/dropoutmask.py
openvino/tools/mo/ops/einsum.py
openvino/tools/mo/ops/elementwise.py
@ -907,6 +912,7 @@ openvino/tools/mo/ops/If.py
openvino/tools/mo/ops/instance_normalization.py
openvino/tools/mo/ops/interp.py
openvino/tools/mo/ops/interpolate.py
openvino/tools/mo/ops/layer_norm.py
openvino/tools/mo/ops/log_softmax.py
openvino/tools/mo/ops/LookupTableInsert.py
openvino/tools/mo/ops/loop.py

View File

@ -55,10 +55,17 @@ class FullyConnectedDecomposer(FrontReplacementSubgraph):
node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0]))
# input normalization for 4D Caffe and MxNet FullyConnected
if graph.graph['fw'] in ['caffe', 'mxnet']:
# input normalization for 4D Caffe and MXNet FullyConnected
if graph.graph['fw'] == 'caffe':
node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1]))
new_op_attrs={'name': name + '/flatten_fc_input', 'special_zero': True},
value=int64_array([0, -1]))
if graph.graph['fw'] == 'mxnet':
if node.flatten is not False:
node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
new_op_attrs={'name': name + '/flatten_fc_input', 'special_zero': True},
value=int64_array([0, -1]))
MatMul.update_node_stat(node, {})

View File

@ -0,0 +1,39 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
from openvino.tools.mo.front.extractor import FrontExtractorOp
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.MatMul import MatMul
class BatchDotExt(FrontExtractorOp):
"""
MXNet operation which computes batch matrix multiplication of x and y similar to TensorFlow or ONNX MatMul operation.
Attributes:
transpose_a - if true then transpose the first input before multiplication
transpose_b - if true then transpose the second input before multiplication
"""
op = 'batch_dot'
enabled = True
@classmethod
def extract(cls, node: Node):
attrs = get_mxnet_layer_attrs(node.symbol_dict)
transpose_a = attrs.bool('transpose_a', False)
transpose_b = attrs.bool('transpose_b', False)
forward_stype = attrs.str('forward_stype', None)
if forward_stype is not None:
log.error("Node {} has non default value {} of attribute forward_stype."
"Model Optimizer conversion assumes default value = None".format(node.soft_get('name', node.id),
forward_stype),
extra={'is_warning': True})
MatMul.update_node_stat(node, {
'transpose_a': transpose_a,
'transpose_b': transpose_b
})
return cls.enabled

View File

@ -0,0 +1,53 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino.tools.mo.front.PowerToEltwises import PowerToEltwises
from openvino.tools.mo.front.common.partial_infer.utils import mo_array
from openvino.tools.mo.front.common.replacement import FrontReplacementOp
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.ops.Cast import Cast
from openvino.tools.mo.ops.ConvertLike import ConvertLike
from openvino.tools.mo.ops.elementwise import Div
from openvino.tools.mo.ops.power import AttributedPower
from openvino.tools.mo.ops.shape import Shape
from openvino.tools.mo.utils.shape import node_to_get_shape_value_of_indices
class DivSqrtDim(FrontReplacementOp):
"""
Replace _contrib_div_sqrt_dim with sub-graph that matches the formula out = (data / sqrt(data.shape[-1]))
"""
op = '_contrib_div_sqrt_dim'
enabled = True
def run_before(self):
return [PowerToEltwises]
def replace_sub_graph(self, graph: Graph, match: dict):
div_sqrt = match['op']
div_sqrt_name = div_sqrt.soft_get('name', div_sqrt.id)
shape_node = Shape(graph, dict(name=div_sqrt_name + '/Shape')).create_node()
data_out_port = div_sqrt.in_port(0).get_source()
shape_node.in_port(0).connect(data_out_port)
shape_values_node = node_to_get_shape_value_of_indices(shape_node=shape_node, indices=[-1])
pow_node = AttributedPower(graph, dict(name=div_sqrt_name + '/Sqrt',
power=mo_array(0.5))).create_node()
# Due to specification, Power must have inputs with the same data type.
convert_pow_input = Cast(graph, dict(dst_type=np.float32,
name=shape_values_node.name + '/ConvertToFP32')).create_node()
convert_pow_output = ConvertLike(graph, dict(name=pow_node.name + 'ConvertLike')).create_node()
div_node = Div(graph, dict(name="Div")).create_node()
shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
convert_pow_input.out_port(0).connect(pow_node.in_port(0))
div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
pow_node.out_port(0).connect(convert_pow_output.in_port(0))
convert_pow_output.in_port(1).connect(data_out_port)
div_node.in_port(1).connect(convert_pow_output.out_port(0))
div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))
rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])

View File

@ -18,6 +18,7 @@ class FullyConnectedFrontExtractor(FrontExtractorOp):
attrs = {
'out-size': num_hidden,
'transpose_weights': True,
'flatten': attr.bool('flatten', True)
}
FullyConnected.update_node_stat(node, attrs)
return cls.enabled

View File

@ -0,0 +1,23 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.extractor import FrontExtractorOp
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.layer_norm import LayerNorm
class LayerNormFrontExtractor(FrontExtractorOp):
op = 'LayerNorm'
enabled = True
@classmethod
def extract(cls, node: Node):
attr = get_mxnet_layer_attrs(node.symbol_dict)
node_attrs = {
'epsilon': attr.float('eps', 9.99999975e-06),
'axis': attr.int('axis', -1),
'output_mean_var': attr.bool('output_mean_var', False)
}
LayerNorm.update_node_stat(node, node_attrs)
return cls.enabled

View File

@ -0,0 +1,75 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
from openvino.tools.mo.front.caffe.extractors.utils import get_canonical_axis_index
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.tf.graph_utils import create_op_node_with_second_input
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
from openvino.tools.mo.ops.elementwise import Mul, Add
from openvino.tools.mo.ops.mvn import MVN
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
from openvino.tools.mo.utils.error import Error
class LayerNormalization(MiddleReplacementPattern):
"""
Decompose LayerNorm(x) to MVN(x) * gamma + beta
LayerNorm is supported with only 1 output.
"""
enabled = True
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='LayerNorm'):
node_name = node.soft_get('name', node.id)
if node.output_mean_var is True:
if not node.out_port(1).disconnected() or not node.out_port(2).disconnected():
raise Error("Node {} is supported with only one output".format(node_name))
log.error('LayerNorm node {} with attribute "output_mean_var" = True is not supported.'
'But since the node has one output, the conversion will continue.'.format(node_name),
extra={'is_warning': True})
input_shape = node.in_port(0).data.get_shape()
assert node.has_valid('axis'), 'Incorrect axis value for the node {}'.format(node_name)
axis = node.axis
mvn = create_op_node_with_second_input(graph, MVN, int64_array([axis]),
dict(eps=node.epsilon, name=node_name + '/LayerNorm/MVN_',
across_channels=1, normalize_variance=1, eps_mode='inside_sqrt'))
mul = Mul(graph, {'name': node_name + '/LayerNorm/mul_'}).create_node()
add = Add(graph, {'name': mul.name + '/LayerNorm/add_'}).create_node()
node.in_port(0).get_connection().set_destination(mvn.in_port(0))
node.in_port(1).get_connection().set_destination(mul.in_port(1))
node.in_port(2).get_connection().set_destination(add.in_port(1))
mvn.out_port(0).connect(mul.in_port(0))
mul.out_port(0).connect(add.in_port(0))
node.out_port(0).get_connection().set_source(add.out_port(0))
# MXNet LayerNorm gamma and beta attributes are 1D tensors with shape = [input_shape[axis]]
# We have to unsqueeze values for Mul and Add operations to avoid shapes incompatibility problems
# if axis != -1
canonical_axis = get_canonical_axis_index(input_shape, axis)
unsqueeze_value = []
for idx, val in enumerate(input_shape):
if idx != canonical_axis:
unsqueeze_value.append(idx)
mul_const_unsqueeze = create_op_node_with_second_input(graph, Unsqueeze,
int64_array(unsqueeze_value),
dict(name=mul.name + '/Unsqueeze',
override_output_shape=True))
add_const_unsqueeze = create_op_node_with_second_input(graph, Unsqueeze,
int64_array(unsqueeze_value),
dict(name=add.name + '/Unsqueeze',
override_output_shape=True))
mul.in_port(1).get_connection().insert_node(mul_const_unsqueeze)
add.in_port(1).get_connection().insert_node(add_const_unsqueeze)
rename_nodes([(node, node_name + '/ShouldBeDeleted'), (add, node_name)])

View File

@ -0,0 +1,22 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.ops.op import Op
class DivSqrtDimOp(Op):
"""
MXNet operation that matches the formula out = (data / sqrt(data.shape[-1])).
Will be replaced with the corresponding sub-graph
"""
op = '_contrib_div_sqrt_dim'
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': None,
'op': self.op,
'infer': None,
'in_ports_count': 1,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)

View File

@ -0,0 +1,40 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.common.partial_infer.elemental import copy_shape_infer
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.ops.op import Op
class LayerNorm(Op):
"""
MXNet operation which normalizes the channels of the input tensor by mean and variance, and applies a scale gamma
and offset beta. Operation computes output with the same shape as input as following formula:
out = ((data - mean(data, axis)) / sqrt(var(data, axis) + eps)) * gamma + beta
inputs:
data - input data
gamma - gamma array
beta - beta array
attributes:
axis - axis to perform layer normalization
eps - epsilon parameter to prevent division by zero
output_mean_var - output the mean and std calculated along the given axis. Default value is False. Non default value
is not supported
"""
op = 'LayerNorm'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'op': self.op,
'type': None,
'axis': -1,
'epsilon': 0.001,
'output_mean_var': False,
'infer': copy_shape_infer,
'in_ports_count': 3 if attrs.get('output_mean_var') is True else 1,
'out_ports_count': 3,
}, attrs)

View File

@ -0,0 +1,59 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.mxnet.div_sqrt_dim import DivSqrtDim
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, shaped_parameter, regular_op_with_empty_data, result, connect, \
shaped_const_with_data, connect_data, connect_front
class DivSqrtDimTest(unittest.TestCase):
def test_1(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**regular_op_with_empty_data('div_sqrt_dim', {'op': '_contrib_div_sqrt_dim'}),
**result('result')
},
edges=[
*connect('input', 'div_sqrt_dim'),
*connect('div_sqrt_dim', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**regular_op_with_empty_data('div_sqrt_shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
**shaped_const_with_data('gather_axis', None),
**shaped_const_with_data('gather_indices', None),
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
**regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}),
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}),
**regular_op_with_empty_data('z_convert_like', {'op': 'ConvertLike', 'type': 'ConvertLike'}),
**regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}),
**result('result')
},
edges=[
*connect('input', '0:div'),
*connect_data('input', 'div_sqrt_shape_of'),
*connect('div_sqrt_shape_of', '0:gather'),
*connect('gather_axis', '1:gather'),
*connect('gather_indices', '2:gather'),
*connect('gather', 'cast'),
*connect('cast', 'power'),
*connect('power', '0:z_convert_like'),
*connect_front('input_d', '1:z_convert_like'),
*connect('z_convert_like', '1:div'),
*connect('div', 'result')
],
)
DivSqrtDim().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -0,0 +1,153 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.middle.layer_normalization import LayerNormalization
from openvino.tools.mo.utils.error import Error
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, shaped_parameter, regular_op_with_empty_data, shaped_const_with_data, \
result, connect
class LayerNormalizationTest(unittest.TestCase):
def test_1(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**regular_op_with_empty_data('layer_norm', {'op': 'LayerNorm', 'epsilon': 0.001, 'axis': -1,
'output_mean_var': False}),
**shaped_const_with_data('gamma', None),
**shaped_const_with_data('beta', None),
**result('result')
},
edges=[
*connect('input', '0:layer_norm'),
*connect('gamma', '1:layer_norm'),
*connect('beta', '2:layer_norm'),
*connect('layer_norm', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**shaped_const_with_data('mvn_const', None),
**regular_op_with_empty_data('mvn', {'eps': 0.001, 'across_channels': 1, 'normalize_variance': 1,
'eps_mode': 'inside_sqrt', 'op': 'MVN', 'type': 'MVN'}),
**shaped_const_with_data('gamma', None),
**regular_op_with_empty_data('gamma_unsqueeze', {'op': 'Unsqueeze', 'type': 'Unsqueeze'}),
**shaped_const_with_data('gamma_unsqueeze_const', None),
**regular_op_with_empty_data('beta_unsqueeze', {'op': 'Unsqueeze', 'type': 'Unsqueeze'}),
**shaped_const_with_data('beta_unsqueeze_const', None),
**regular_op_with_empty_data('mul', {'op': 'Mul', 'type': 'Multiply'}),
**shaped_const_with_data('beta', None),
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
**result('result')
},
edges=[
*connect('input', '0:mvn'),
*connect('mvn_const', '1:mvn'),
*connect('mvn', '0:mul'),
*connect('gamma', 'gamma_unsqueeze'),
*connect('gamma_unsqueeze_const', '1:gamma_unsqueeze'),
*connect('gamma_unsqueeze', '1:mul'),
*connect('mul', '0:add'),
*connect('beta', 'beta_unsqueeze'),
*connect('beta_unsqueeze_const', '1:beta_unsqueeze'),
*connect('beta_unsqueeze', '1:add'),
*connect('add', 'result')
],
update_attributes={
'mvn_const': {'value': int64_array([-1]), 'shape': int64_array([1])},
'gamma_unsqueeze_const': {'value': int64_array([0, 1, 2]), 'shape': int64_array([3])},
'beta_unsqueeze_const': {'value': int64_array([0, 1, 2]), 'shape': int64_array([3])}
}
)
LayerNormalization().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_2(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**regular_op_with_empty_data('layer_norm', {'op': 'LayerNorm', 'epsilon': 0.001, 'axis': 1,
'output_mean_var': False}),
**shaped_const_with_data('gamma', None),
**shaped_const_with_data('beta', None),
**result('result')
},
edges=[
*connect('input', '0:layer_norm'),
*connect('gamma', '1:layer_norm'),
*connect('beta', '2:layer_norm'),
*connect('layer_norm', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**shaped_const_with_data('mvn_const', None),
**regular_op_with_empty_data('mvn', {'eps': 0.001, 'across_channels': 1, 'normalize_variance': 1,
'eps_mode': 'inside_sqrt', 'op': 'MVN', 'type': 'MVN'}),
**shaped_const_with_data('gamma', None),
**regular_op_with_empty_data('gamma_unsqueeze', {'op': 'Unsqueeze', 'type': 'Unsqueeze'}),
**shaped_const_with_data('gamma_unsqueeze_const', None),
**regular_op_with_empty_data('beta_unsqueeze', {'op': 'Unsqueeze', 'type': 'Unsqueeze'}),
**shaped_const_with_data('beta_unsqueeze_const', None),
**regular_op_with_empty_data('mul', {'op': 'Mul', 'type': 'Multiply'}),
**shaped_const_with_data('beta', None),
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
**result('result')
},
edges=[
*connect('input', '0:mvn'),
*connect('mvn_const', '1:mvn'),
*connect('mvn', '0:mul'),
*connect('gamma', 'gamma_unsqueeze'),
*connect('gamma_unsqueeze_const', '1:gamma_unsqueeze'),
*connect('gamma_unsqueeze', '1:mul'),
*connect('mul', '0:add'),
*connect('beta', 'beta_unsqueeze'),
*connect('beta_unsqueeze_const', '1:beta_unsqueeze'),
*connect('beta_unsqueeze', '1:add'),
*connect('add', 'result')
],
update_attributes={
'mvn_const': {'value': int64_array([1]), 'shape': int64_array([1])},
'gamma_unsqueeze_const': {'value': int64_array([0, 2, 3]), 'shape': int64_array([3])},
'beta_unsqueeze_const': {'value': int64_array([0, 2, 3]), 'shape': int64_array([3])}
}
)
LayerNormalization().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_negative(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 15, 15])),
**regular_op_with_empty_data('layer_norm', {'op': 'LayerNorm', 'epsilon': 0.001, 'axis': -1,
'output_mean_var': True}),
**shaped_const_with_data('gamma', None),
**shaped_const_with_data('beta', None),
**result('result'),
**result('result_1'),
**result('result_2')
},
edges=[
*connect('input', '0:layer_norm'),
*connect('gamma', '1:layer_norm'),
*connect('beta', '2:layer_norm'),
*connect('layer_norm:0', 'result'),
*connect('layer_norm:1', 'result_1'),
*connect('layer_norm:2', 'result_2')
]
)
with self.assertRaises(Error):
LayerNormalization().find_and_replace_pattern(graph)