Extend MO for operation Einsum-7 (#5401)
* Extend MO for operation Einsum-7 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Add extractor for einsum and optimize code based on review feedback Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix the code based on the review: correct code, tests and comments; move insert_transpose Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix LayoutChangeForEinsum transformation condition Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Update third-party dependencies Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
9db7f849df
commit
dc22c177d5
@ -29,6 +29,7 @@ extensions/back/GroupedConvWeightsNormalize.py
|
||||
extensions/back/insert_compatibility_l2normalization.py
|
||||
extensions/back/InterpolateReshape.py
|
||||
extensions/back/kaldi_remove_memory_output.py
|
||||
extensions/back/LayoutChangeForEinsum.py
|
||||
extensions/back/LayoutChangeForGatherND.py
|
||||
extensions/back/LeakyReLUMutation.py
|
||||
extensions/back/LinearToLinearONNXReplacer.py
|
||||
@ -193,6 +194,7 @@ extensions/front/mxnet/custom_rpn_proposal.py
|
||||
extensions/front/mxnet/deformable_conv_ext.py
|
||||
extensions/front/mxnet/deformable_psroi_pooling_ext.py
|
||||
extensions/front/mxnet/dropout_ext.py
|
||||
extensions/front/mxnet/einsum_ext.py
|
||||
extensions/front/mxnet/elementwise_ext.py
|
||||
extensions/front/mxnet/eltwise_scalar_replacers.py
|
||||
extensions/front/mxnet/exp_ext.py
|
||||
@ -278,6 +280,7 @@ extensions/front/onnx/dequantize_linear_resolver.py
|
||||
extensions/front/onnx/detection_output.py
|
||||
extensions/front/onnx/detectionoutput_ext.py
|
||||
extensions/front/onnx/dropout_ext.py
|
||||
extensions/front/onnx/einsum_ext.py
|
||||
extensions/front/onnx/elementwise_ext.py
|
||||
extensions/front/onnx/expand_ext.py
|
||||
extensions/front/onnx/faster_rcnn.json
|
||||
@ -404,6 +407,7 @@ extensions/front/tf/deconv_ext.py
|
||||
extensions/front/tf/depth_to_space.py
|
||||
extensions/front/tf/efficient_det_support_api_v2.0.json
|
||||
extensions/front/tf/efficient_det_support_api_v2.4.json
|
||||
extensions/front/tf/einsum_ext.py
|
||||
extensions/front/tf/elementwise_ext.py
|
||||
extensions/front/tf/embedding_segments_sum.py
|
||||
extensions/front/tf/expand_dims_ext.py
|
||||
@ -678,6 +682,7 @@ extensions/ops/dequantize_linear.py
|
||||
extensions/ops/DetectionOutput.py
|
||||
extensions/ops/detectionoutput_onnx.py
|
||||
extensions/ops/dft.py
|
||||
extensions/ops/einsum.py
|
||||
extensions/ops/elementwise.py
|
||||
extensions/ops/embedding_bag.py
|
||||
extensions/ops/Enter.py
|
||||
@ -1028,6 +1033,7 @@ mo/utils/ir_reader/extenders/convert_extender.py
|
||||
mo/utils/ir_reader/extenders/ctc_greedy_decoder_seq_len_extender.py
|
||||
mo/utils/ir_reader/extenders/deconvolution_extender.py
|
||||
mo/utils/ir_reader/extenders/deformable_convolution_extender.py
|
||||
mo/utils/ir_reader/extenders/einsum_extender.py
|
||||
mo/utils/ir_reader/extenders/experimental_extender.py
|
||||
mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
|
||||
mo/utils/ir_reader/extenders/fakequantize_extender.py
|
||||
|
57
model-optimizer/extensions/back/LayoutChangeForEinsum.py
Normal file
57
model-optimizer/extensions/back/LayoutChangeForEinsum.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.ops.einsum import Einsum
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
|
||||
|
||||
class LayoutChangeForEinsum(BackReplacementPattern):
|
||||
"""
|
||||
The transformation adjusts Einsum equation to NCHW layout.
|
||||
Subscripts for tensor of rank greater than three must be adjusted
|
||||
to NCHW layout, meaning a label for the last dimension is moved
|
||||
to the second position in the subscript.
|
||||
There is an exception when the last label in the subscript is ellipsis
|
||||
and covers multiple dimensions. In this case subscript is not changed and
|
||||
Transpose to get original NHWC layout back is inserted.
|
||||
The transformation is only applicable to TensorFlow case.
|
||||
"""
|
||||
enabled = True
|
||||
force_shape_inference = True
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
import extensions.middle.InsertLayoutPropagationTransposes as InsertTransposes
|
||||
for einsum in graph.get_op_nodes(type='Einsum'):
|
||||
einsum_name = einsum.soft_get('name', einsum.id)
|
||||
assert einsum.has_valid('equation'), "Equation attribute is mandatory" \
|
||||
" for Einsum node {}".format(einsum_name)
|
||||
equation = einsum.equation
|
||||
connected_in_ports = [port for port in einsum.in_ports().values() if not port.disconnected()]
|
||||
num_inputs = len(connected_in_ports)
|
||||
|
||||
# compute a mask of inputs of rank greater than 3 that are required original layout (NCHW)
|
||||
# due to presence of ellipsis covering multiple tail dimensions in the corresponding input subscript
|
||||
input_ranks = [len(einsum.in_port(port_idx).data.get_shape()) for port_idx in range(num_inputs)]
|
||||
output_rank = len(einsum.out_port(0).data.get_shape())
|
||||
permuted_equation, is_inputs_permuted, is_output_permuted = Einsum.adjust_equation_with_NCHW_layout(
|
||||
einsum_name,
|
||||
equation,
|
||||
input_ranks,
|
||||
output_rank)
|
||||
assert len(is_inputs_permuted) == num_inputs
|
||||
|
||||
# setup adjusted equation
|
||||
einsum.equation = permuted_equation
|
||||
|
||||
# insert Transpose node to get NHWC layout back (for inputs) that is required due to specifics of equation
|
||||
for input_ind in range(num_inputs):
|
||||
if not is_inputs_permuted[input_ind]:
|
||||
# that means Einsum can only accept input in NHWC layout
|
||||
# so the inserted transpose before the Einsum will convert the layout to NHWC
|
||||
InsertTransposes.insert_transpose(graph, einsum.in_port(input_ind), before_input=True)
|
||||
if not is_output_permuted:
|
||||
# that means Einsum can only generate output in NHWC layout
|
||||
# so the inserted transpose followed after the output will convert the layout back into NCHW layout
|
||||
InsertTransposes.insert_transpose(graph, einsum.out_port(0), before_input=False)
|
20
model-optimizer/extensions/front/mxnet/einsum_ext.py
Normal file
20
model-optimizer/extensions/front/mxnet/einsum_ext.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.ops.einsum import Einsum
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
|
||||
|
||||
|
||||
class EinsumExtractor(FrontExtractorOp):
|
||||
op = '_npi_einsum'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, einsum_node):
|
||||
einsum_name = einsum_node.soft_get('name', einsum_node.id)
|
||||
attrs = get_mxnet_layer_attrs(einsum_node.symbol_dict)
|
||||
equation = attrs.str('subscripts')
|
||||
normalized_equation = Einsum.normalize_equation(einsum_name, equation)
|
||||
Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
|
||||
return cls.enabled
|
19
model-optimizer/extensions/front/onnx/einsum_ext.py
Normal file
19
model-optimizer/extensions/front/onnx/einsum_ext.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.ops.einsum import Einsum
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
|
||||
|
||||
class EinsumExtractor(FrontExtractorOp):
|
||||
op = 'Einsum'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, einsum_node):
|
||||
einsum_name = einsum_node.soft_get('name', einsum_node.id)
|
||||
equation = onnx_attr(einsum_node, 'equation', 's').decode(encoding="utf-8")
|
||||
normalized_equation = Einsum.normalize_equation(einsum_name, equation)
|
||||
Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
|
||||
return cls.enabled
|
18
model-optimizer/extensions/front/tf/einsum_ext.py
Normal file
18
model-optimizer/extensions/front/tf/einsum_ext.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.ops.einsum import Einsum
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
class EinsumExtractor(FrontExtractorOp):
|
||||
op = 'Einsum'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, einsum_node):
|
||||
einsum_name = einsum_node.soft_get('name', einsum_node.id)
|
||||
equation = einsum_node.pb.attr['equation'].s.decode('utf-8')
|
||||
normalized_equation = Einsum.normalize_equation(einsum_name, equation)
|
||||
Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
|
||||
return cls.enabled
|
232
model-optimizer/extensions/ops/einsum.py
Normal file
232
model-optimizer/extensions/ops/einsum.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op
|
||||
from mo.utils.broadcasting import bi_directional_shape_broadcasting
|
||||
|
||||
|
||||
class Einsum(Op):
|
||||
op = 'Einsum'
|
||||
enabled = False
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': self.op,
|
||||
'op': self.op,
|
||||
'version': 'opset7',
|
||||
'infer': self.infer,
|
||||
'out_ports_count': 1,
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
||||
|
||||
def backend_attrs(self):
|
||||
return ['equation']
|
||||
|
||||
@staticmethod
|
||||
def parse_equation(node_name: str, equation: str) -> (list, str):
|
||||
"""
|
||||
Parse Einsum equation and check that its format is correct to make sure that
|
||||
all input subscripts consists of only alphabetic letters or alphabetic letters with one ellipsis.
|
||||
In case of implicit mode the method recovers the right-hand part.
|
||||
|
||||
:param node_name: Einsum node name for which to parse an equation
|
||||
:param equation: Equation to be parsed and checked
|
||||
:return: A tuple of a list of input subscripts and output subscript
|
||||
"""
|
||||
# normalize equation by removing white-spaces
|
||||
equation = equation.strip()
|
||||
|
||||
# split equation into the left and right hands
|
||||
splitted_equation = equation.split('->')
|
||||
assert len(splitted_equation) <= 2, "Einsum node {} has `equation` of incorrect format".format(node_name)
|
||||
|
||||
# split left-hand side of the equation and check a format of input subscripts
|
||||
input_subscripts = splitted_equation[0]
|
||||
input_subscripts_list = input_subscripts.split(',')
|
||||
|
||||
# prepare pattern to check a format of subscripts
|
||||
subscript_pattern = re.compile("^[a-zA-Z]*(\\.\\.\\.){0,1}[a-zA-Z]*$")
|
||||
ellipsis_pattern = re.compile("\\.\\.\\.")
|
||||
|
||||
is_ellipsis_met = False
|
||||
for input_subscript in input_subscripts_list:
|
||||
assert re.match(subscript_pattern, input_subscript) is not None, \
|
||||
"Einsum node {} has `equation` with incorrect input subscript: {}".format(node_name, input_subscript)
|
||||
is_ellipsis_met = is_ellipsis_met or re.search(ellipsis_pattern, input_subscript)
|
||||
|
||||
if len(splitted_equation) == 2:
|
||||
output_subscript = splitted_equation[1]
|
||||
assert re.match(subscript_pattern, output_subscript), \
|
||||
"Einsum node {} has `equation` with incorrect output subscript: {}".format(node_name, output_subscript)
|
||||
# if ellipsis is met, the output subscript must contain it as well
|
||||
if is_ellipsis_met:
|
||||
assert re.search(ellipsis_pattern, output_subscript), \
|
||||
"The output subscript of Einsum node {} must contain ellipsis".format(node_name)
|
||||
elif len(splitted_equation) == 1:
|
||||
# recover output subscript in case implicit mode
|
||||
output_subscript = ''.join(input_subscripts_list)
|
||||
output_subscript = ''.join(sorted(list(set(output_subscript) - {'.'})))
|
||||
if is_ellipsis_met:
|
||||
output_subscript = "..." + output_subscript
|
||||
else:
|
||||
assert False, "Einsum node {} equation has incorrect format. " \
|
||||
"It must be in either explicit or implicit mode.".format(node_name)
|
||||
|
||||
return input_subscripts_list, output_subscript
|
||||
|
||||
@staticmethod
|
||||
def normalize_equation(node_name: str, equation: str) -> str:
|
||||
"""
|
||||
Recover explicit mode of equation.
|
||||
|
||||
:param node_name: Einsum node name for which to recover explicit mode
|
||||
:param equation: Einsum equation to recover explicit mode
|
||||
:return: Recovered equation in explicit mode
|
||||
"""
|
||||
input_subscripts_list, output_subscript = Einsum.parse_equation(node_name, equation)
|
||||
return ','.join(input_subscripts_list) + "->" + output_subscript
|
||||
|
||||
@staticmethod
|
||||
def extract_subscript_labels(node_name: str, subscript: str) -> list:
|
||||
"""
|
||||
Extract labels for given subscript. Each label can be either alphabetic letter or ellipsis
|
||||
|
||||
:param node_name: Einsum node name
|
||||
:param subscript: Given subscript
|
||||
:return: A list of labels
|
||||
"""
|
||||
labels = []
|
||||
len_subscript = len(subscript)
|
||||
label_ind = 0
|
||||
while label_ind < len_subscript:
|
||||
if subscript[label_ind].isalpha():
|
||||
labels.append(subscript[label_ind])
|
||||
label_ind += 1
|
||||
elif len_subscript - label_ind > 2 and subscript[label_ind:label_ind + 3] == "...":
|
||||
labels.append("...")
|
||||
label_ind += 3
|
||||
else:
|
||||
assert False, "Einsum node {} has `equation` with incorrect subscript: {}".format(node_name, subscript)
|
||||
return labels
|
||||
|
||||
@staticmethod
|
||||
def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int) -> (
|
||||
str, list, bool):
|
||||
"""
|
||||
In order to satisfy NCHW layout, subscripts for tensors with rank greater than three must be adjusted by moving labels
|
||||
of the last dimension to the second position in the subscript. There is an exception for such tensors when
|
||||
the label is ellipsis and it covers multiple tail dimensions. The method returns equation with adjusted subscripts
|
||||
to NCHW layout along with a boolean mask to indicate which subscripts are adjusted.
|
||||
|
||||
:param node_name: Einsum node name for which equation is adjusted
|
||||
:param equation: Equation to be adjusted
|
||||
:param input_ranks: a list of input ranks
|
||||
:param output_rank: output rank
|
||||
:return: adjusted equation, boolean mask for inputs, and boolean flag if output subscript is adjusted
|
||||
"""
|
||||
is_inputs_permuted = []
|
||||
input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
|
||||
num_inputs = len(input_ranks)
|
||||
assert len(input_subscripts) == num_inputs, "The number of inputs must match a number " \
|
||||
"of input subscripts"
|
||||
|
||||
# permute labels in input subscripts and mark inputs for which inference in NCHW layout is acceptable
|
||||
# in case ellipsis covering multiple dimensions in the end, the permutation is impossible
|
||||
# so the corresponding input must be in the original format (NHWC)
|
||||
permuted_input_subscripts = []
|
||||
for input_ind in range(num_inputs):
|
||||
input_subscript = input_subscripts[input_ind]
|
||||
input_rank = input_ranks[input_ind]
|
||||
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
|
||||
num_broadcasted_dims = input_rank - len(labels) + 1
|
||||
if input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
|
||||
is_inputs_permuted.append(True)
|
||||
labels.insert(1, labels[-1])
|
||||
del labels[-1]
|
||||
else:
|
||||
is_inputs_permuted.append(False)
|
||||
permuted_input_subscript = ''.join(labels)
|
||||
permuted_input_subscripts.append(permuted_input_subscript)
|
||||
|
||||
# perform the same procedure for the output subscript as for the inputs subscripts
|
||||
labels = Einsum.extract_subscript_labels(node_name, output_subscript)
|
||||
num_broadcasted_dims = output_rank - len(labels) + 1
|
||||
if output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
|
||||
is_output_permuted = True
|
||||
labels.insert(1, labels[-1])
|
||||
del labels[-1]
|
||||
else:
|
||||
is_output_permuted = False
|
||||
permuted_output_subscript = ''.join(labels)
|
||||
|
||||
# concatenate the left and right hands of the resulted equation
|
||||
left_hand = ','.join(permuted_input_subscripts)
|
||||
right_hand = permuted_output_subscript
|
||||
permuted_equation = left_hand + "->" + right_hand
|
||||
return permuted_equation, is_inputs_permuted, is_output_permuted
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
num_inputs = len(connected_in_ports)
|
||||
assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
|
||||
equation = node.equation
|
||||
|
||||
# parse the equation and extract input and output subscripts
|
||||
input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
|
||||
|
||||
# check that each operand has the corresponding input subscript
|
||||
assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
|
||||
"must match the number of input subscripts " \
|
||||
"in `equation`".format(node_name)
|
||||
|
||||
# check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
|
||||
label_to_shape = {}
|
||||
for input_ind in range(num_inputs):
|
||||
input_shape = node.in_port(input_ind).data.get_shape()
|
||||
input_subscript = input_subscripts[input_ind]
|
||||
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
|
||||
num_dims = len(input_shape)
|
||||
num_labels = len(labels)
|
||||
num_broadcasted_dims = num_dims - num_labels + 1
|
||||
dim_ind = 0
|
||||
label_ind = 0
|
||||
while label_ind < num_labels and dim_ind < num_dims:
|
||||
label = labels[label_ind]
|
||||
if label == "...":
|
||||
sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
|
||||
if label in label_to_shape.keys():
|
||||
common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
|
||||
assert common_shape is not None, "The dimensions labeled of ellipsis must be broadcastable " \
|
||||
"for Einsum node {}".format(node_name)
|
||||
label_to_shape[label] = common_shape
|
||||
else:
|
||||
label_to_shape[label] = sub_shape
|
||||
dim_ind += num_broadcasted_dims
|
||||
else:
|
||||
dim_size = input_shape[dim_ind]
|
||||
sub_shape = int64_array([dim_size])
|
||||
assert label not in label_to_shape.keys() or np.array_equal(label_to_shape[label], sub_shape), \
|
||||
"Sizes of dimensions with the same label of Einsum node {} " \
|
||||
"must be compatible".format(node_name)
|
||||
label_to_shape[label] = sub_shape
|
||||
dim_ind += 1
|
||||
label_ind += 1
|
||||
|
||||
# generate output shape based on the output subscript
|
||||
output_shape = int64_array([])
|
||||
labels = Einsum.extract_subscript_labels(node_name, output_subscript)
|
||||
for label in labels:
|
||||
assert label in label_to_shape.keys(), "The label in the output subscript must appear" \
|
||||
" in input subscripts in equation {} " \
|
||||
"of Einsum node {}".format(equation, node_name)
|
||||
output_shape = np.concatenate((output_shape, label_to_shape[label]))
|
||||
|
||||
node.out_port(0).data.set_shape(output_shape)
|
@ -0,0 +1,17 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from mo.utils.graph import Node
|
||||
from mo.utils.ir_reader.extender import Extender
|
||||
|
||||
|
||||
class Einsum_extender(Extender):
|
||||
op = 'Einsum'
|
||||
|
||||
@staticmethod
|
||||
def extend(op: Node):
|
||||
einsum_name = op.soft_get('name', op.id)
|
||||
if isinstance(op['equation'], list):
|
||||
op['equation'] = ','.join(op['equation'])
|
||||
elif not isinstance(op['equation'], str):
|
||||
assert False, "Equation of Einsum node {} has incorrect format.".format(einsum_name)
|
@ -0,0 +1,82 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.LayoutChangeForEinsum import LayoutChangeForEinsum
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect
|
||||
|
||||
nodes_attributes = {
|
||||
# Parameter layers
|
||||
**regular_op_with_shaped_data('placeholder_1', None, {'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('placeholder_2', None, {'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('placeholder_3', None, {'type': 'Parameter', 'op': 'Parameter'}),
|
||||
|
||||
# Einsum layer
|
||||
**regular_op_with_shaped_data('einsum', None, {'type': 'Einsum', 'op': 'Einsum'}),
|
||||
|
||||
# Result layer
|
||||
**result(),
|
||||
|
||||
# Transpose layers
|
||||
**regular_op_with_shaped_data('transpose_1', None,
|
||||
{'type': 'Transpose', 'op': 'Transpose', 'need_shape_inference': True}),
|
||||
**regular_op_with_shaped_data('transpose_3', None,
|
||||
{'type': 'Transpose', 'op': 'Transpose', 'need_shape_inference': True}),
|
||||
|
||||
# Const layers
|
||||
**valued_const_with_data('axis_1_const', int64_array([0, 2, 3, 1])),
|
||||
**valued_const_with_data('axis_3_const', int64_array([0, 4, 1, 2, 3])),
|
||||
}
|
||||
|
||||
|
||||
class LayoutChangeForEinsumTests(unittest.TestCase):
|
||||
def test_layout_change_einsum(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[*connect('placeholder_1', '0:einsum'),
|
||||
*connect('placeholder_2', '1:einsum'),
|
||||
*connect('placeholder_3', '2:einsum'),
|
||||
*connect('einsum', 'output')],
|
||||
{ # this input stays as is since it is of a rank equal to 3
|
||||
'placeholder_1_d': {'shape': np.array([2, 3, 5])},
|
||||
# [3, 5, 7, 8] - NHWC, [3, 8, 5, 7] - NCHW
|
||||
# this input does not require additional transpose
|
||||
# since the corresponding subscript can be adjusted
|
||||
'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
|
||||
# [3, 5, 10, 12] - NHWC, [3, 12, 5, 10] - NCHW
|
||||
# the third input must be transposed to NHWC layout
|
||||
# since ellipsis covers multiple dimensions in the end
|
||||
# the corresponding subscript is not changed
|
||||
'placeholder_3_d': {'shape': np.array([3, 12, 8, 10])},
|
||||
# equation is still for NHWC layout
|
||||
'einsum': {'equation': "abc,bcde,bc...->ade..."},
|
||||
# [2, 7, 8, 10, 12] - NHWC, [2, 12, 7, 8, 10] - NCHW
|
||||
# the output is in NCHW layout but its shape will be re-inferred since
|
||||
# the output stays in NHWC layout due to ellipsis in the end
|
||||
# and additional transpose to NCHW will be inserted
|
||||
'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])},
|
||||
}, nodes_with_edges_only=True)
|
||||
graph.graph['fw'] = 'tf'
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[*connect('placeholder_3', '0:transpose_1'),
|
||||
*connect('axis_1_const', '1:transpose_1'),
|
||||
*connect('placeholder_1', '0:einsum'),
|
||||
*connect('placeholder_2', '1:einsum'),
|
||||
*connect('transpose_1', '2:einsum'),
|
||||
*connect('einsum', '0:transpose_3'),
|
||||
*connect('axis_3_const', '1:transpose_3'),
|
||||
*connect('transpose_3', 'output')],
|
||||
{'placeholder_1_d': {'shape': np.array([2, 3, 5])},
|
||||
'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
|
||||
'einsum': {'equation': "abc,becd,bc...->ade..."},
|
||||
'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])}
|
||||
})
|
||||
|
||||
LayoutChangeForEinsum().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
90
model-optimizer/unit_tests/extensions/ops/einsum_test.py
Normal file
90
model-optimizer/unit_tests/extensions/ops/einsum_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.ops.einsum import Einsum
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Node
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, result, connect
|
||||
|
||||
|
||||
def create_einsum_graph(input_shapes: list, equation: str) -> Graph:
|
||||
num_inputs = len(input_shapes)
|
||||
assert num_inputs > 0, "Einsum node must have at least one input"
|
||||
nodes = {}
|
||||
edges = []
|
||||
for input_ind in range(num_inputs):
|
||||
input_name = 'input' + str(input_ind)
|
||||
parameter_op = regular_op_with_shaped_data(input_name, input_shapes[input_ind],
|
||||
{'op': 'Parameter', 'type': 'Parameter'})
|
||||
nodes.update(parameter_op)
|
||||
edges += connect(input_name, str(input_ind) + ":einsum_node")
|
||||
einsum_op = regular_op_with_shaped_data('einsum_node', None,
|
||||
{'op': 'Einsum', 'type': 'Einsum', 'equation': equation})
|
||||
nodes.update(einsum_op)
|
||||
result_op = result('output')
|
||||
nodes.update(result_op)
|
||||
edges += connect('einsum_node', 'output')
|
||||
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
return graph
|
||||
|
||||
|
||||
@generator
|
||||
class TestEinsum(unittest.TestCase):
|
||||
@generate(*[
|
||||
# dot product
|
||||
([int64_array([10]), int64_array([10])], "i,i->", int64_array([])),
|
||||
# matrix multiplication
|
||||
([int64_array([2, 3]), int64_array([3, 4])], "ab,bc->ac", int64_array([2, 4])),
|
||||
# trace per batch
|
||||
([int64_array([2, 3, 3])], "kii->k", int64_array([2])),
|
||||
# diagonal extraction
|
||||
([int64_array([6, 5, 5])], "kii->ki", int64_array([6, 5])),
|
||||
# transpose
|
||||
([int64_array([1, 2, 3])], "ijk->kij", int64_array([3, 1, 2])),
|
||||
# multiple matrix multiplication
|
||||
([int64_array([2, 5]), int64_array([5, 3, 6]), int64_array([5, 3])], "ab,bcd,bc->ca", int64_array([3, 2])),
|
||||
# ellipsis for one operand
|
||||
([int64_array([5, 3, 4])], "a...->...", int64_array([3, 4])),
|
||||
# ellipsis for multiple operands
|
||||
([int64_array([3, 5]), int64_array([1])], "a...,...->a...", int64_array([3, 5])),
|
||||
# ellipsis with broadcasting
|
||||
([int64_array([9, 1, 4, 3]), int64_array([3, 11, 7, 1])], "a...b,b...->a...", int64_array([9, 11, 7, 4])),
|
||||
# mixed case letters in equation
|
||||
([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])),
|
||||
# mixed case letters and equation in implicit mode
|
||||
([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])),
|
||||
])
|
||||
def test_einsum(self, input_shapes, equation, ref_output_shape):
|
||||
graph = create_einsum_graph(input_shapes, equation)
|
||||
einsum_node = Node(graph, 'einsum_node')
|
||||
Einsum.infer(einsum_node)
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['einsum_node_d']['shape']
|
||||
|
||||
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
|
||||
'shape does not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
@generate(*[
|
||||
# incorrect subscript numbers or inputs
|
||||
([int64_array([3, 11]), int64_array([11, 4])], "ab,bc,cd->ac", None),
|
||||
# invalid labels
|
||||
([int64_array([3, 11]), int64_array([11, 4])], "a$,Bc->ac", None),
|
||||
# incompatible shapes
|
||||
([int64_array([3, 11]), int64_array([12, 4])], "ab,bc->ac", None),
|
||||
# not broadcastable shapes
|
||||
([int64_array([11, 1, 4, 3]), int64_array([3, 11, 7, 5])], "a...b,b...->a...", None),
|
||||
# missed ellipsis
|
||||
([int64_array([11, 1, 4, 3]), int64_array([3, 11, 7, 4])], "a...b,b...->a", None),
|
||||
])
|
||||
def test_invalid_cases(self, input_shapes, equation, ref_output_shape):
|
||||
graph = create_einsum_graph(input_shapes, equation)
|
||||
einsum_node = Node(graph, 'einsum_node')
|
||||
self.assertRaises(AssertionError, Einsum.infer, einsum_node)
|
@ -14,7 +14,7 @@ nodes_attributes = {'data': {'kind': 'op'},
|
||||
'data_data': {'shape': None, 'value': None, 'kind': 'data'},
|
||||
'indices': {'kind': 'op'},
|
||||
'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
|
||||
'gathernd_node': {'op': 'ScatterNDUpdate', 'kind': 'op', 'batch_dims': 0},
|
||||
'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0},
|
||||
'output': {'shape': None, 'value': None, 'kind': 'data'}}
|
||||
|
||||
# graph 1
|
||||
@ -118,7 +118,7 @@ inputs_inv2 = {'data_data': {'shape': int64_array([10, 40, 20]), 'value': None},
|
||||
inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value': None},
|
||||
'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
|
||||
|
||||
class TestScatterNDUpdate(unittest.TestCase):
|
||||
class TestGatherNDUpdate(unittest.TestCase):
|
||||
def setUp(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user