Extend MO for operation Einsum-7 (#5401)

* Extend MO for operation Einsum-7

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add extractor for einsum and optimize code based on review feedback

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix the code based on the review: correct code, tests and comments; move insert_transpose

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix LayoutChangeForEinsum transformation condition

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Update third-party dependencies

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
Roman Kazantsev 2021-05-11 21:36:04 +03:00 committed by GitHub
parent 9db7f849df
commit dc22c177d5
10 changed files with 543 additions and 2 deletions

View File: automation/package_BOM.txt

@@ -29,6 +29,7 @@ extensions/back/GroupedConvWeightsNormalize.py
extensions/back/insert_compatibility_l2normalization.py
extensions/back/InterpolateReshape.py
extensions/back/kaldi_remove_memory_output.py
extensions/back/LayoutChangeForEinsum.py
extensions/back/LayoutChangeForGatherND.py
extensions/back/LeakyReLUMutation.py
extensions/back/LinearToLinearONNXReplacer.py
@@ -193,6 +194,7 @@ extensions/front/mxnet/custom_rpn_proposal.py
extensions/front/mxnet/deformable_conv_ext.py
extensions/front/mxnet/deformable_psroi_pooling_ext.py
extensions/front/mxnet/dropout_ext.py
extensions/front/mxnet/einsum_ext.py
extensions/front/mxnet/elementwise_ext.py
extensions/front/mxnet/eltwise_scalar_replacers.py
extensions/front/mxnet/exp_ext.py
@@ -278,6 +280,7 @@ extensions/front/onnx/dequantize_linear_resolver.py
extensions/front/onnx/detection_output.py
extensions/front/onnx/detectionoutput_ext.py
extensions/front/onnx/dropout_ext.py
extensions/front/onnx/einsum_ext.py
extensions/front/onnx/elementwise_ext.py
extensions/front/onnx/expand_ext.py
extensions/front/onnx/faster_rcnn.json
@@ -404,6 +407,7 @@ extensions/front/tf/deconv_ext.py
extensions/front/tf/depth_to_space.py
extensions/front/tf/efficient_det_support_api_v2.0.json
extensions/front/tf/efficient_det_support_api_v2.4.json
extensions/front/tf/einsum_ext.py
extensions/front/tf/elementwise_ext.py
extensions/front/tf/embedding_segments_sum.py
extensions/front/tf/expand_dims_ext.py
@@ -678,6 +682,7 @@ extensions/ops/dequantize_linear.py
extensions/ops/DetectionOutput.py
extensions/ops/detectionoutput_onnx.py
extensions/ops/dft.py
extensions/ops/einsum.py
extensions/ops/elementwise.py
extensions/ops/embedding_bag.py
extensions/ops/Enter.py
@@ -1028,6 +1033,7 @@ mo/utils/ir_reader/extenders/convert_extender.py
mo/utils/ir_reader/extenders/ctc_greedy_decoder_seq_len_extender.py
mo/utils/ir_reader/extenders/deconvolution_extender.py
mo/utils/ir_reader/extenders/deformable_convolution_extender.py
mo/utils/ir_reader/extenders/einsum_extender.py
mo/utils/ir_reader/extenders/experimental_extender.py
mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
mo/utils/ir_reader/extenders/fakequantize_extender.py

View File: extensions/back/LayoutChangeForEinsum.py

@@ -0,0 +1,57 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.einsum import Einsum
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
class LayoutChangeForEinsum(BackReplacementPattern):
"""
The transformation adjusts Einsum equation to NCHW layout.
Subscripts for tensors of rank greater than three must be adjusted
to NCHW layout, meaning the label for the last dimension is moved
to the second position in the subscript.
There is an exception when the last label in a subscript is an ellipsis
covering multiple dimensions: in this case the subscript is left unchanged,
and a Transpose restoring the original NHWC layout is inserted instead.
The transformation is only applicable to TensorFlow case.
"""
enabled = True
force_shape_inference = True
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
def find_and_replace_pattern(self, graph: Graph):
import extensions.middle.InsertLayoutPropagationTransposes as InsertTransposes
for einsum in graph.get_op_nodes(type='Einsum'):
einsum_name = einsum.soft_get('name', einsum.id)
assert einsum.has_valid('equation'), "Equation attribute is mandatory" \
" for Einsum node {}".format(einsum_name)
equation = einsum.equation
connected_in_ports = [port for port in einsum.in_ports().values() if not port.disconnected()]
num_inputs = len(connected_in_ports)
# determine which inputs of rank greater than 3 must keep the original NHWC layout
# due to an ellipsis covering multiple tail dimensions in the corresponding input subscript
input_ranks = [len(einsum.in_port(port_idx).data.get_shape()) for port_idx in range(num_inputs)]
output_rank = len(einsum.out_port(0).data.get_shape())
permuted_equation, is_inputs_permuted, is_output_permuted = Einsum.adjust_equation_with_NCHW_layout(
einsum_name,
equation,
input_ranks,
output_rank)
assert len(is_inputs_permuted) == num_inputs
# setup adjusted equation
einsum.equation = permuted_equation
# insert Transpose nodes for inputs that must get the NHWC layout back, as required by the equation
for input_ind in range(num_inputs):
if not is_inputs_permuted[input_ind]:
# this means Einsum can accept this input only in NHWC layout,
# so the Transpose inserted before Einsum converts the layout to NHWC
InsertTransposes.insert_transpose(graph, einsum.in_port(input_ind), before_input=True)
if not is_output_permuted:
# this means Einsum can generate its output only in NHWC layout,
# so the Transpose inserted after Einsum converts the layout back to NCHW
InsertTransposes.insert_transpose(graph, einsum.out_port(0), before_input=False)
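
For illustration, a minimal sketch (not part of the commit) of how the equation adjustment behaves on the equation from the unit test further below; 'einsum_demo' is an arbitrary node name:

from extensions.ops.einsum import Einsum

# NHWC equation; input ranks are [3, 4, 4] and the output rank is 5
equation = "abc,bcde,bc...->ade..."
permuted_equation, is_inputs_permuted, is_output_permuted = \
    Einsum.adjust_equation_with_NCHW_layout('einsum_demo', equation, [3, 4, 4], 5)
print(permuted_equation)   # abc,becd,bc...->ade...
print(is_inputs_permuted)  # [False, True, False]
print(is_output_permuted)  # False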

View File: extensions/front/mxnet/einsum_ext.py

@@ -0,0 +1,20 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.einsum import Einsum
from mo.front.extractor import FrontExtractorOp
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
class EinsumExtractor(FrontExtractorOp):
op = '_npi_einsum'
enabled = True
@classmethod
def extract(cls, einsum_node):
einsum_name = einsum_node.soft_get('name', einsum_node.id)
attrs = get_mxnet_layer_attrs(einsum_node.symbol_dict)
equation = attrs.str('subscripts')
normalized_equation = Einsum.normalize_equation(einsum_name, equation)
Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
return cls.enabled
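
As a quick reference, a sketch of what normalize_equation returns for an implicit-mode equation (implicit mode keeps labels that occur exactly once, sorted alphabetically); the node name 'demo' is illustrative:

from extensions.ops.einsum import Einsum

# the explicit right-hand side is recovered: labels occurring more than once
# are reduced, the remaining labels are sorted alphabetically
print(Einsum.normalize_equation('demo', 'ab,bc'))  # ab,bc->ac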

View File: extensions/front/onnx/einsum_ext.py

@@ -0,0 +1,19 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.einsum import Einsum
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
class EinsumExtractor(FrontExtractorOp):
op = 'Einsum'
enabled = True
@classmethod
def extract(cls, einsum_node):
einsum_name = einsum_node.soft_get('name', einsum_node.id)
equation = onnx_attr(einsum_node, 'equation', 's').decode(encoding="utf-8")
normalized_equation = Einsum.normalize_equation(einsum_name, equation)
Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
return cls.enabled
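
A sketch of the ONNX node this extractor consumes, built with the onnx helper API; the tensor names are illustrative:

import onnx

node = onnx.helper.make_node('Einsum', inputs=['x', 'y'], outputs=['z'], equation='ij,jk->ik')
print(node.attribute[0].s.decode('utf-8'))  # ij,jk->ik, the string read via onnx_attr above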

View File: extensions/front/tf/einsum_ext.py

@@ -0,0 +1,18 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.einsum import Einsum
from mo.front.extractor import FrontExtractorOp
class EinsumExtractor(FrontExtractorOp):
op = 'Einsum'
enabled = True
@classmethod
def extract(cls, einsum_node):
einsum_name = einsum_node.soft_get('name', einsum_node.id)
equation = einsum_node.pb.attr['equation'].s.decode('utf-8')
normalized_equation = Einsum.normalize_equation(einsum_name, equation)
Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
return cls.enabled
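
A sketch of the TensorFlow NodeDef attribute access mirrored by this extractor; the node is built by hand for illustration:

from tensorflow.core.framework import node_def_pb2

pb = node_def_pb2.NodeDef(name='einsum', op='Einsum')
pb.attr['equation'].s = 'ij,jk->ik'.encode('utf-8')
print(pb.attr['equation'].s.decode('utf-8'))  # ij,jk->ik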

View File: extensions/ops/einsum.py

@@ -0,0 +1,232 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import re
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.broadcasting import bi_directional_shape_broadcasting
class Einsum(Op):
op = 'Einsum'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': self.op,
'op': self.op,
'version': 'opset7',
'infer': self.infer,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)
def backend_attrs(self):
return ['equation']
@staticmethod
def parse_equation(node_name: str, equation: str) -> (list, str):
"""
Parse the Einsum equation and check that its format is correct: each input subscript must
consist of alphabetic letters only, optionally with one ellipsis.
In case of implicit mode the method recovers the right-hand part.
:param node_name: Einsum node name for which to parse an equation
:param equation: Equation to be parsed and checked
:return: A tuple of a list of input subscripts and output subscript
"""
# normalize equation by removing white-spaces
equation = equation.strip()
# split equation into the left and right hands
splitted_equation = equation.split('->')
assert len(splitted_equation) <= 2, "Einsum node {} has `equation` of incorrect format".format(node_name)
# split left-hand side of the equation and check a format of input subscripts
input_subscripts = splitted_equation[0]
input_subscripts_list = input_subscripts.split(',')
# prepare pattern to check a format of subscripts
subscript_pattern = re.compile("^[a-zA-Z]*(\\.\\.\\.){0,1}[a-zA-Z]*$")
ellipsis_pattern = re.compile("\\.\\.\\.")
is_ellipsis_met = False
for input_subscript in input_subscripts_list:
assert re.match(subscript_pattern, input_subscript) is not None, \
"Einsum node {} has `equation` with incorrect input subscript: {}".format(node_name, input_subscript)
is_ellipsis_met = is_ellipsis_met or re.search(ellipsis_pattern, input_subscript)
if len(splitted_equation) == 2:
output_subscript = splitted_equation[1]
assert re.match(subscript_pattern, output_subscript), \
"Einsum node {} has `equation` with incorrect output subscript: {}".format(node_name, output_subscript)
# if ellipsis is met, the output subscript must contain it as well
if is_ellipsis_met:
assert re.search(ellipsis_pattern, output_subscript), \
"The output subscript of Einsum node {} must contain ellipsis".format(node_name)
elif len(splitted_equation) == 1:
# recover the output subscript for implicit mode:
# it contains labels that occur exactly once across all input subscripts,
# sorted alphabetically (the numpy and ONNX implicit-mode convention)
all_labels = ''.join(input_subscripts_list).replace('.', '')
output_subscript = ''.join(sorted(label for label in set(all_labels) if all_labels.count(label) == 1))
if is_ellipsis_met:
output_subscript = "..." + output_subscript
else:
assert False, "Einsum node {} equation has incorrect format. " \
"It must be in either explicit or implicit mode.".format(node_name)
return input_subscripts_list, output_subscript
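# worked example (a sketch): parse_equation('node', "cba") recovers the implicit
# right-hand side and returns (['cba'], 'abc'), while the explicit
# parse_equation('node', "ab,bc->ac") returns (['ab', 'bc'], 'ac')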
@staticmethod
def normalize_equation(node_name: str, equation: str) -> str:
"""
Recover explicit mode of equation.
:param node_name: Einsum node name for which to recover explicit mode
:param equation: Einsum equation to recover explicit mode
:return: Recovered equation in explicit mode
"""
input_subscripts_list, output_subscript = Einsum.parse_equation(node_name, equation)
return ','.join(input_subscripts_list) + "->" + output_subscript
@staticmethod
def extract_subscript_labels(node_name: str, subscript: str) -> list:
"""
Extract labels for given subscript. Each label can be either alphabetic letter or ellipsis
:param node_name: Einsum node name
:param subscript: Given subscript
:return: A list of labels
"""
labels = []
len_subscript = len(subscript)
label_ind = 0
while label_ind < len_subscript:
if subscript[label_ind].isalpha():
labels.append(subscript[label_ind])
label_ind += 1
elif len_subscript - label_ind > 2 and subscript[label_ind:label_ind + 3] == "...":
labels.append("...")
label_ind += 3
else:
assert False, "Einsum node {} has `equation` with incorrect subscript: {}".format(node_name, subscript)
return labels
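# worked example (a sketch): extract_subscript_labels('node', "a...bc")
# returns ['a', '...', 'b', 'c']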
@staticmethod
def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int) -> (
str, list, bool):
"""
In order to satisfy NCHW layout, subscripts for tensors of rank greater than three must be adjusted by moving the label
of the last dimension to the second position in the subscript. There is an exception for such tensors when
the last label is an ellipsis covering multiple tail dimensions. The method returns the equation with subscripts
adjusted to NCHW layout along with boolean masks indicating which subscripts were adjusted.
:param node_name: Einsum node name for which equation is adjusted
:param equation: Equation to be adjusted
:param input_ranks: a list of input ranks
:param output_rank: output rank
:return: adjusted equation, boolean mask for inputs, and boolean flag if output subscript is adjusted
"""
is_inputs_permuted = []
input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
num_inputs = len(input_ranks)
assert len(input_subscripts) == num_inputs, "The number of inputs must match the number " \
"of input subscripts"
# permute labels in input subscripts and mark inputs for which inference in NCHW layout is acceptable;
# if an ellipsis covers multiple tail dimensions, the permutation is impossible,
# so the corresponding input must stay in the original format (NHWC)
permuted_input_subscripts = []
for input_ind in range(num_inputs):
input_subscript = input_subscripts[input_ind]
input_rank = input_ranks[input_ind]
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
num_broadcasted_dims = input_rank - len(labels) + 1
if input_rank > 3 and (labels[-1] != "..." or (labels[-1] == "..." and num_broadcasted_dims == 1)):
is_inputs_permuted.append(True)
labels.insert(1, labels[-1])
del labels[-1]
else:
is_inputs_permuted.append(False)
permuted_input_subscript = ''.join(labels)
permuted_input_subscripts.append(permuted_input_subscript)
# perform the same procedure for the output subscript as for the inputs subscripts
labels = Einsum.extract_subscript_labels(node_name, output_subscript)
num_broadcasted_dims = output_rank - len(labels) + 1
if output_rank > 3 and (labels[-1] != "..." or (labels[-1] == "..." and num_broadcasted_dims == 1)):
is_output_permuted = True
labels.insert(1, labels[-1])
del labels[-1]
else:
is_output_permuted = False
permuted_output_subscript = ''.join(labels)
# concatenate the left and right hands of the resulted equation
left_hand = ','.join(permuted_input_subscripts)
right_hand = permuted_output_subscript
permuted_equation = left_hand + "->" + right_hand
return permuted_equation, is_inputs_permuted, is_output_permuted
@staticmethod
def infer(node: Node):
node_name = node.soft_get('name', node.id)
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
num_inputs = len(connected_in_ports)
assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
equation = node.equation
# parse the equation and extract input and output subscripts
input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
# check that each operand has the corresponding input subscript
assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
"must match the number of input subscripts " \
"in `equation`".format(node_name)
# check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
label_to_shape = {}
for input_ind in range(num_inputs):
input_shape = node.in_port(input_ind).data.get_shape()
input_subscript = input_subscripts[input_ind]
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
num_dims = len(input_shape)
num_labels = len(labels)
num_broadcasted_dims = num_dims - num_labels + 1
dim_ind = 0
label_ind = 0
while label_ind < num_labels and dim_ind < num_dims:
label = labels[label_ind]
if label == "...":
sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
if label in label_to_shape.keys():
common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
assert common_shape is not None, "The dimensions covered by ellipsis must be broadcastable " \
"for Einsum node {}".format(node_name)
label_to_shape[label] = common_shape
else:
label_to_shape[label] = sub_shape
dim_ind += num_broadcasted_dims
else:
dim_size = input_shape[dim_ind]
sub_shape = int64_array([dim_size])
assert label not in label_to_shape.keys() or np.array_equal(label_to_shape[label], sub_shape), \
"Sizes of dimensions with the same label of Einsum node {} " \
"must be compatible".format(node_name)
label_to_shape[label] = sub_shape
dim_ind += 1
label_ind += 1
# generate output shape based on the output subscript
output_shape = int64_array([])
labels = Einsum.extract_subscript_labels(node_name, output_subscript)
for label in labels:
assert label in label_to_shape.keys(), "The label in the output subscript must appear" \
" in input subscripts in equation {} " \
"of Einsum node {}".format(equation, node_name)
output_shape = np.concatenate((output_shape, label_to_shape[label]))
node.out_port(0).data.set_shape(output_shape)
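
The shape inference above can be cross-checked against numpy for concrete shapes; a minimal sketch using one of the broadcasting cases from the unit tests below:

import numpy as np

# ellipsis-covered dimensions are broadcast: [1, 4] vs [11, 7, 1] -> [11, 7, 4]
out = np.einsum("a...b,b...->a...", np.zeros((9, 1, 4, 3)), np.zeros((3, 11, 7, 1)))
print(out.shape)  # (9, 11, 7, 4)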

View File: mo/utils/ir_reader/extenders/einsum_extender.py

@@ -0,0 +1,17 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from mo.utils.graph import Node
from mo.utils.ir_reader.extender import Extender
class Einsum_extender(Extender):
op = 'Einsum'
@staticmethod
def extend(op: Node):
einsum_name = op.soft_get('name', op.id)
if isinstance(op['equation'], list):
op['equation'] = ','.join(op['equation'])
elif not isinstance(op['equation'], str):
assert False, "Equation of Einsum node {} has incorrect format.".format(einsum_name)

View File: unit tests for LayoutChangeForEinsum

@@ -0,0 +1,82 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.back.LayoutChangeForEinsum import LayoutChangeForEinsum
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect
nodes_attributes = {
# Parameter layers
**regular_op_with_shaped_data('placeholder_1', None, {'type': 'Parameter', 'op': 'Parameter'}),
**regular_op_with_shaped_data('placeholder_2', None, {'type': 'Parameter', 'op': 'Parameter'}),
**regular_op_with_shaped_data('placeholder_3', None, {'type': 'Parameter', 'op': 'Parameter'}),
# Einsum layer
**regular_op_with_shaped_data('einsum', None, {'type': 'Einsum', 'op': 'Einsum'}),
# Result layer
**result(),
# Transpose layers
**regular_op_with_shaped_data('transpose_1', None,
{'type': 'Transpose', 'op': 'Transpose', 'need_shape_inference': True}),
**regular_op_with_shaped_data('transpose_3', None,
{'type': 'Transpose', 'op': 'Transpose', 'need_shape_inference': True}),
# Const layers
**valued_const_with_data('axis_1_const', int64_array([0, 2, 3, 1])),
**valued_const_with_data('axis_3_const', int64_array([0, 4, 1, 2, 3])),
}
class LayoutChangeForEinsumTests(unittest.TestCase):
def test_layout_change_einsum(self):
graph = build_graph(nodes_attributes,
[*connect('placeholder_1', '0:einsum'),
*connect('placeholder_2', '1:einsum'),
*connect('placeholder_3', '2:einsum'),
*connect('einsum', 'output')],
{ # this input stays as is since its rank equals 3
'placeholder_1_d': {'shape': np.array([2, 3, 5])},
# [3, 5, 7, 8] - NHWC, [3, 8, 5, 7] - NCHW
# this input does not require additional transpose
# since the corresponding subscript can be adjusted
'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
# [3, 5, 10, 12] - NHWC, [3, 12, 5, 10] - NCHW
# the third input must be transposed to NHWC layout because
# the ellipsis covers multiple tail dimensions,
# so the corresponding subscript is not changed
'placeholder_3_d': {'shape': np.array([3, 12, 5, 10])},
# equation is still for NHWC layout
'einsum': {'equation': "abc,bcde,bc...->ade..."},
# [2, 7, 8, 10, 12] - NHWC, [2, 12, 7, 8, 10] - NCHW
# the Einsum output stays in NHWC layout due to the trailing ellipsis,
# so its shape will be re-inferred and an additional
# Transpose to NCHW will be inserted after it
'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])},
}, nodes_with_edges_only=True)
graph.graph['fw'] = 'tf'
graph_ref = build_graph(nodes_attributes,
[*connect('placeholder_3', '0:transpose_1'),
*connect('axis_1_const', '1:transpose_1'),
*connect('placeholder_1', '0:einsum'),
*connect('placeholder_2', '1:einsum'),
*connect('transpose_1', '2:einsum'),
*connect('einsum', '0:transpose_3'),
*connect('axis_3_const', '1:transpose_3'),
*connect('transpose_3', 'output')],
{'placeholder_1_d': {'shape': np.array([2, 3, 5])},
'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
'einsum': {'equation': "abc,becd,bc...->ade..."},
'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])}
})
LayoutChangeForEinsum().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

View File: unit tests for Einsum shape inference

@@ -0,0 +1,90 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from generator import generator, generate
from extensions.ops.einsum import Einsum
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph
from mo.graph.graph import Node
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, result, connect
def create_einsum_graph(input_shapes: list, equation: str) -> Graph:
num_inputs = len(input_shapes)
assert num_inputs > 0, "Einsum node must have at least one input"
nodes = {}
edges = []
for input_ind in range(num_inputs):
input_name = 'input' + str(input_ind)
parameter_op = regular_op_with_shaped_data(input_name, input_shapes[input_ind],
{'op': 'Parameter', 'type': 'Parameter'})
nodes.update(parameter_op)
edges += connect(input_name, str(input_ind) + ":einsum_node")
einsum_op = regular_op_with_shaped_data('einsum_node', None,
{'op': 'Einsum', 'type': 'Einsum', 'equation': equation})
nodes.update(einsum_op)
result_op = result('output')
nodes.update(result_op)
edges += connect('einsum_node', 'output')
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
return graph
@generator
class TestEinsum(unittest.TestCase):
@generate(*[
# dot product
([int64_array([10]), int64_array([10])], "i,i->", int64_array([])),
# matrix multiplication
([int64_array([2, 3]), int64_array([3, 4])], "ab,bc->ac", int64_array([2, 4])),
# trace per batch
([int64_array([2, 3, 3])], "kii->k", int64_array([2])),
# diagonal extraction
([int64_array([6, 5, 5])], "kii->ki", int64_array([6, 5])),
# transpose
([int64_array([1, 2, 3])], "ijk->kij", int64_array([3, 1, 2])),
# multiple matrix multiplication
([int64_array([2, 5]), int64_array([5, 3, 6]), int64_array([5, 3])], "ab,bcd,bc->ca", int64_array([3, 2])),
# ellipsis for one operand
([int64_array([5, 3, 4])], "a...->...", int64_array([3, 4])),
# ellipsis for multiple operands
([int64_array([3, 5]), int64_array([1])], "a...,...->a...", int64_array([3, 5])),
# ellipsis with broadcasting
([int64_array([9, 1, 4, 3]), int64_array([3, 11, 7, 1])], "a...b,b...->a...", int64_array([9, 11, 7, 4])),
# implicit mode with mixed case letters
([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])),
# implicit mode with mixed case letters and ellipsis
([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])),
])
def test_einsum(self, input_shapes, equation, ref_output_shape):
graph = create_einsum_graph(input_shapes, equation)
einsum_node = Node(graph, 'einsum_node')
Einsum.infer(einsum_node)
# get the result
res_output_shape = graph.node['einsum_node_d']['shape']
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'shape does not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
@generate(*[
# the number of input subscripts does not match the number of inputs
([int64_array([3, 11]), int64_array([11, 4])], "ab,bc,cd->ac", None),
# invalid labels
([int64_array([3, 11]), int64_array([11, 4])], "a$,Bc->ac", None),
# incompatible shapes
([int64_array([3, 11]), int64_array([12, 4])], "ab,bc->ac", None),
# non-broadcastable shapes
([int64_array([11, 1, 4, 3]), int64_array([3, 11, 7, 5])], "a...b,b...->a...", None),
# missing ellipsis in the output subscript
([int64_array([11, 1, 4, 3]), int64_array([3, 11, 7, 4])], "a...b,b...->a", None),
])
def test_invalid_cases(self, input_shapes, equation, ref_output_shape):
graph = create_einsum_graph(input_shapes, equation)
einsum_node = Node(graph, 'einsum_node')
self.assertRaises(AssertionError, Einsum.infer, einsum_node)

View File

@@ -14,7 +14,7 @@ nodes_attributes = {'data': {'kind': 'op'},
'data_data': {'shape': None, 'value': None, 'kind': 'data'},
'indices': {'kind': 'op'},
'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
'gathernd_node': {'op': 'ScatterNDUpdate', 'kind': 'op', 'batch_dims': 0},
'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0},
'output': {'shape': None, 'value': None, 'kind': 'data'}}
# graph 1
@@ -118,7 +118,7 @@ inputs_inv2 = {'data_data': {'shape': int64_array([10, 40, 20]), 'value': None},
inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value': None},
'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
class TestScatterNDUpdate(unittest.TestCase):
class TestGatherNDUpdate(unittest.TestCase):
def setUp(self):
nodes_attributes['gathernd_node']['batch_dims'] = 0