diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt
index d8b90b715a4..d236df8f3d2 100644
--- a/model-optimizer/automation/package_BOM.txt
+++ b/model-optimizer/automation/package_BOM.txt
@@ -29,6 +29,7 @@ extensions/back/GroupedConvWeightsNormalize.py
 extensions/back/insert_compatibility_l2normalization.py
 extensions/back/InterpolateReshape.py
 extensions/back/kaldi_remove_memory_output.py
+extensions/back/LayoutChangeForEinsum.py
 extensions/back/LayoutChangeForGatherND.py
 extensions/back/LeakyReLUMutation.py
 extensions/back/LinearToLinearONNXReplacer.py
@@ -193,6 +194,7 @@ extensions/front/mxnet/custom_rpn_proposal.py
 extensions/front/mxnet/deformable_conv_ext.py
 extensions/front/mxnet/deformable_psroi_pooling_ext.py
 extensions/front/mxnet/dropout_ext.py
+extensions/front/mxnet/einsum_ext.py
 extensions/front/mxnet/elementwise_ext.py
 extensions/front/mxnet/eltwise_scalar_replacers.py
 extensions/front/mxnet/exp_ext.py
@@ -278,6 +280,7 @@ extensions/front/onnx/dequantize_linear_resolver.py
 extensions/front/onnx/detection_output.py
 extensions/front/onnx/detectionoutput_ext.py
 extensions/front/onnx/dropout_ext.py
+extensions/front/onnx/einsum_ext.py
 extensions/front/onnx/elementwise_ext.py
 extensions/front/onnx/expand_ext.py
 extensions/front/onnx/faster_rcnn.json
@@ -404,6 +407,7 @@ extensions/front/tf/deconv_ext.py
 extensions/front/tf/depth_to_space.py
 extensions/front/tf/efficient_det_support_api_v2.0.json
 extensions/front/tf/efficient_det_support_api_v2.4.json
+extensions/front/tf/einsum_ext.py
 extensions/front/tf/elementwise_ext.py
 extensions/front/tf/embedding_segments_sum.py
 extensions/front/tf/expand_dims_ext.py
@@ -678,6 +682,7 @@ extensions/ops/dequantize_linear.py
 extensions/ops/DetectionOutput.py
 extensions/ops/detectionoutput_onnx.py
 extensions/ops/dft.py
+extensions/ops/einsum.py
 extensions/ops/elementwise.py
 extensions/ops/embedding_bag.py
 extensions/ops/Enter.py
@@ -1028,6 +1033,7 @@ mo/utils/ir_reader/extenders/convert_extender.py
 mo/utils/ir_reader/extenders/ctc_greedy_decoder_seq_len_extender.py
 mo/utils/ir_reader/extenders/deconvolution_extender.py
 mo/utils/ir_reader/extenders/deformable_convolution_extender.py
+mo/utils/ir_reader/extenders/einsum_extender.py
 mo/utils/ir_reader/extenders/experimental_extender.py
 mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
 mo/utils/ir_reader/extenders/fakequantize_extender.py
diff --git a/model-optimizer/extensions/back/LayoutChangeForEinsum.py b/model-optimizer/extensions/back/LayoutChangeForEinsum.py
new file mode 100644
index 00000000000..f45bff54b93
--- /dev/null
+++ b/model-optimizer/extensions/back/LayoutChangeForEinsum.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from extensions.ops.einsum import Einsum
+from mo.back.replacement import BackReplacementPattern
+from mo.graph.graph import Graph
+
+
+class LayoutChangeForEinsum(BackReplacementPattern):
+    """
+    The transformation adjusts the Einsum equation to the NCHW layout.
+    Subscripts for tensors of rank greater than three must be adjusted
+    to the NCHW layout, meaning that the label for the last dimension is moved
+    to the second position in the subscript.
+    There is an exception when the last label in the subscript is an ellipsis
+    that covers multiple dimensions. In this case the subscript is not changed and
+    a Transpose to get the original NHWC layout back is inserted.
+    The transformation is only applicable to the TensorFlow case.
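+
+    For example, the subscript 'abcd' for a 4D input is adjusted to 'adbc', so that
+    shape inference can run directly on the NCHW shape without an extra Transpose.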
+    """
+    enabled = True
+    force_shape_inference = True
+    graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+
+    def find_and_replace_pattern(self, graph: Graph):
+        import extensions.middle.InsertLayoutPropagationTransposes as InsertTransposes
+        for einsum in graph.get_op_nodes(type='Einsum'):
+            einsum_name = einsum.soft_get('name', einsum.id)
+            assert einsum.has_valid('equation'), "Equation attribute is mandatory" \
+                                                 " for Einsum node {}".format(einsum_name)
+            equation = einsum.equation
+            connected_in_ports = [port for port in einsum.in_ports().values() if not port.disconnected()]
+            num_inputs = len(connected_in_ports)
+
+            # compute a mask of inputs of rank greater than 3 that require the original layout (NHWC)
+            # due to an ellipsis covering multiple tail dimensions in the corresponding input subscript
+            input_ranks = [len(einsum.in_port(port_idx).data.get_shape()) for port_idx in range(num_inputs)]
+            output_rank = len(einsum.out_port(0).data.get_shape())
+            permuted_equation, is_inputs_permuted, is_output_permuted = Einsum.adjust_equation_with_NCHW_layout(
+                einsum_name,
+                equation,
+                input_ranks,
+                output_rank)
+            assert len(is_inputs_permuted) == num_inputs
+
+            # set up the adjusted equation
+            einsum.equation = permuted_equation
+
+            # insert Transpose nodes to get the NHWC layout back for inputs where the equation requires it
+            for input_ind in range(num_inputs):
+                if not is_inputs_permuted[input_ind]:
+                    # that means Einsum can only accept this input in the NHWC layout,
+                    # so the Transpose inserted before Einsum converts the layout to NHWC
+                    InsertTransposes.insert_transpose(graph, einsum.in_port(input_ind), before_input=True)
+            if not is_output_permuted:
+                # that means Einsum can only generate the output in the NHWC layout,
+                # so the Transpose inserted after the output converts the layout back into NCHW
+                InsertTransposes.insert_transpose(graph, einsum.out_port(0), before_input=False)
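+
+# A worked example (mirrored by the unit test below): for the equation "abc,bcde,bc...->ade..."
+# with input ranks [3, 4, 4] and output rank 5, the adjusted equation is "abc,becd,bc...->ade...".
+# Only the second subscript can be permuted; the third input and the output keep their subscripts
+# because of the multi-dimensional ellipsis, so Transpose nodes are inserted for them.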
diff --git a/model-optimizer/extensions/front/mxnet/einsum_ext.py b/model-optimizer/extensions/front/mxnet/einsum_ext.py
new file mode 100644
index 00000000000..445cc38f242
--- /dev/null
+++ b/model-optimizer/extensions/front/mxnet/einsum_ext.py
@@ -0,0 +1,20 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from extensions.ops.einsum import Einsum
+from mo.front.extractor import FrontExtractorOp
+from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
+
+
+class EinsumExtractor(FrontExtractorOp):
+    op = '_npi_einsum'
+    enabled = True
+
+    @classmethod
+    def extract(cls, einsum_node):
+        einsum_name = einsum_node.soft_get('name', einsum_node.id)
+        attrs = get_mxnet_layer_attrs(einsum_node.symbol_dict)
+        equation = attrs.str('subscripts')
+        normalized_equation = Einsum.normalize_equation(einsum_name, equation)
+        Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
+        return cls.enabled
diff --git a/model-optimizer/extensions/front/onnx/einsum_ext.py b/model-optimizer/extensions/front/onnx/einsum_ext.py
new file mode 100644
index 00000000000..f2a0336ab65
--- /dev/null
+++ b/model-optimizer/extensions/front/onnx/einsum_ext.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from extensions.ops.einsum import Einsum
+from mo.front.extractor import FrontExtractorOp
+from mo.front.onnx.extractors.utils import onnx_attr
+
+
+class EinsumExtractor(FrontExtractorOp):
+    op = 'Einsum'
+    enabled = True
+
+    @classmethod
+    def extract(cls, einsum_node):
+        einsum_name = einsum_node.soft_get('name', einsum_node.id)
+        equation = onnx_attr(einsum_node, 'equation', 's').decode(encoding="utf-8")
+        normalized_equation = Einsum.normalize_equation(einsum_name, equation)
+        Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
+        return cls.enabled
diff --git a/model-optimizer/extensions/front/tf/einsum_ext.py b/model-optimizer/extensions/front/tf/einsum_ext.py
new file mode 100644
index 00000000000..290fa948183
--- /dev/null
+++ b/model-optimizer/extensions/front/tf/einsum_ext.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from extensions.ops.einsum import Einsum
+from mo.front.extractor import FrontExtractorOp
+
+
+class EinsumExtractor(FrontExtractorOp):
+    op = 'Einsum'
+    enabled = True
+
+    @classmethod
+    def extract(cls, einsum_node):
+        einsum_name = einsum_node.soft_get('name', einsum_node.id)
+        equation = einsum_node.pb.attr['equation'].s.decode('utf-8')
+        normalized_equation = Einsum.normalize_equation(einsum_name, equation)
+        Einsum.update_node_stat(einsum_node, {'equation': normalized_equation})
+        return cls.enabled
diff --git a/model-optimizer/extensions/ops/einsum.py b/model-optimizer/extensions/ops/einsum.py
new file mode 100644
index 00000000000..907989216ac
--- /dev/null
+++ b/model-optimizer/extensions/ops/einsum.py
@@ -0,0 +1,232 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+import numpy as np
+
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node, Graph
+from mo.ops.op import Op
+from mo.utils.broadcasting import bi_directional_shape_broadcasting
+
+
+class Einsum(Op):
+    op = 'Einsum'
+    enabled = False
+
+    def __init__(self, graph: Graph, attrs: dict):
+        mandatory_props = {
+            'type': self.op,
+            'op': self.op,
+            'version': 'opset7',
+            'infer': self.infer,
+            'out_ports_count': 1,
+        }
+        super().__init__(graph, mandatory_props, attrs)
+
+    def backend_attrs(self):
+        return ['equation']
+
+    @staticmethod
+    def parse_equation(node_name: str, equation: str) -> (list, str):
+        """
+        Parse the Einsum equation and check that its format is correct: all input subscripts
+        must consist of only alphabetic letters or alphabetic letters with one ellipsis.
+        In case of the implicit mode the method recovers the right-hand side.
+
+        :param node_name: Einsum node name for which to parse an equation
+        :param equation: Equation to be parsed and checked
+        :return: A tuple of a list of input subscripts and the output subscript
+        """
+        # remove leading and trailing whitespaces
+        equation = equation.strip()
+
+        # split the equation into the left-hand and right-hand sides
+        splitted_equation = equation.split('->')
+        assert len(splitted_equation) <= 2, "Einsum node {} has `equation` of incorrect format".format(node_name)
+
+        # split the left-hand side of the equation and check the format of input subscripts
+        input_subscripts = splitted_equation[0]
+        input_subscripts_list = input_subscripts.split(',')
+
+        # prepare patterns to check the format of subscripts
+        subscript_pattern = re.compile("^[a-zA-Z]*(\\.\\.\\.){0,1}[a-zA-Z]*$")
+        ellipsis_pattern = re.compile("\\.\\.\\.")
+
+        is_ellipsis_met = False
+        for input_subscript in input_subscripts_list:
+            assert re.match(subscript_pattern, input_subscript) is not None, \
+                "Einsum node {} has `equation` with incorrect input subscript: {}".format(node_name, input_subscript)
+            is_ellipsis_met = is_ellipsis_met or re.search(ellipsis_pattern, input_subscript)
+
+        if len(splitted_equation) == 2:
+            output_subscript = splitted_equation[1]
+            assert re.match(subscript_pattern, output_subscript), \
+                "Einsum node {} has `equation` with incorrect output subscript: {}".format(node_name, output_subscript)
+            # if an ellipsis is met in an input subscript, the output subscript must contain it as well
+            if is_ellipsis_met:
+                assert re.search(ellipsis_pattern, output_subscript), \
+                    "The output subscript of Einsum node {} must contain ellipsis".format(node_name)
+        elif len(splitted_equation) == 1:
+            # recover the output subscript in case of the implicit mode
+            output_subscript = ''.join(input_subscripts_list)
+            output_subscript = ''.join(sorted(list(set(output_subscript) - {'.'})))
+            if is_ellipsis_met:
+                output_subscript = "..." + output_subscript
+        else:
+            assert False, "Einsum node {} equation has incorrect format. " \
+                          "It must be in either explicit or implicit mode.".format(node_name)
+
+        return input_subscripts_list, output_subscript
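+
+    # A sketch of the parsing behavior (consistent with the unit tests):
+    #   parse_equation('node', 'ab,bc->ac')  returns (['ab', 'bc'], 'ac')
+    #   parse_equation('node', 'a...b,B...') returns (['a...b', 'B...'], '...Bab')
+    # In the implicit mode the output subscript is built from all distinct labels
+    # sorted in ASCII order, with '...' prepended if an ellipsis is met.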
+
+    @staticmethod
+    def normalize_equation(node_name: str, equation: str) -> str:
+        """
+        Recover the explicit mode of the equation.
+
+        :param node_name: Einsum node name for which to recover the explicit mode
+        :param equation: Einsum equation for which to recover the explicit mode
+        :return: The recovered equation in the explicit mode
+        """
+        input_subscripts_list, output_subscript = Einsum.parse_equation(node_name, equation)
+        return ','.join(input_subscripts_list) + "->" + output_subscript
+
+    @staticmethod
+    def extract_subscript_labels(node_name: str, subscript: str) -> list:
+        """
+        Extract labels for the given subscript. Each label is either an alphabetic letter or an ellipsis.
+
+        :param node_name: Einsum node name
+        :param subscript: Given subscript
+        :return: A list of labels
+        """
+        labels = []
+        len_subscript = len(subscript)
+        label_ind = 0
+        while label_ind < len_subscript:
+            if subscript[label_ind].isalpha():
+                labels.append(subscript[label_ind])
+                label_ind += 1
+            elif len_subscript - label_ind > 2 and subscript[label_ind:label_ind + 3] == "...":
+                labels.append("...")
+                label_ind += 3
+            else:
+                assert False, "Einsum node {} has `equation` with incorrect subscript: {}".format(node_name, subscript)
+        return labels
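+
+    # For example, extract_subscript_labels('node', 'ab...c') returns ['a', 'b', '...', 'c'].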
+
+    @staticmethod
+    def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int) -> (
+            str, list, bool):
+        """
+        In order to satisfy the NCHW layout, subscripts for tensors of rank greater than three must be adjusted
+        by moving the label of the last dimension to the second position in the subscript. There is an exception
+        for such tensors when the last label is an ellipsis and it covers multiple tail dimensions.
+        The method returns the equation with subscripts adjusted to the NCHW layout along with boolean masks
+        that indicate which subscripts are adjusted.
+
+        :param node_name: Einsum node name for which the equation is adjusted
+        :param equation: Equation to be adjusted
+        :param input_ranks: a list of input ranks
+        :param output_rank: output rank
+        :return: adjusted equation, a boolean mask for inputs, and a boolean flag whether the output subscript is adjusted
+        """
+        is_inputs_permuted = []
+        input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
+        num_inputs = len(input_ranks)
+        assert len(input_subscripts) == num_inputs, "The number of inputs must match the number " \
+                                                    "of input subscripts"
+
+        # permute labels in input subscripts and mark inputs for which inference in the NCHW layout is acceptable;
+        # in case of an ellipsis covering multiple tail dimensions, the permutation is impossible,
+        # so the corresponding input must stay in the original format (NHWC)
+        permuted_input_subscripts = []
+        for input_ind in range(num_inputs):
+            input_subscript = input_subscripts[input_ind]
+            input_rank = input_ranks[input_ind]
+            labels = Einsum.extract_subscript_labels(node_name, input_subscript)
+            num_broadcasted_dims = input_rank - len(labels) + 1
+            if input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
+                is_inputs_permuted.append(True)
+                labels.insert(1, labels[-1])
+                del labels[-1]
+            else:
+                is_inputs_permuted.append(False)
+            permuted_input_subscript = ''.join(labels)
+            permuted_input_subscripts.append(permuted_input_subscript)
+
+        # perform the same procedure for the output subscript as for the input subscripts
+        labels = Einsum.extract_subscript_labels(node_name, output_subscript)
+        num_broadcasted_dims = output_rank - len(labels) + 1
+        if output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
+            is_output_permuted = True
+            labels.insert(1, labels[-1])
+            del labels[-1]
+        else:
+            is_output_permuted = False
+        permuted_output_subscript = ''.join(labels)
+
+        # concatenate the left-hand and right-hand sides of the resulting equation
+        left_hand = ','.join(permuted_input_subscripts)
+        right_hand = permuted_output_subscript
+        permuted_equation = left_hand + "->" + right_hand
+        return permuted_equation, is_inputs_permuted, is_output_permuted
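+
+    # Sketch: adjust_equation_with_NCHW_layout('node', 'abc,bcde,bc...->ade...', [3, 4, 4], 5)
+    # returns ('abc,becd,bc...->ade...', [False, True, False], False): only 'bcde' is permuted
+    # (to 'becd'); the rank-3 input and the subscripts ending with a multi-dimensional ellipsis
+    # are left as is and are handled by Transpose insertion in LayoutChangeForEinsum, which the
+    # unit test shows materializes only for tensors of rank four and above.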
+
+    @staticmethod
+    def infer(node: Node):
+        node_name = node.soft_get('name', node.id)
+        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
+        num_inputs = len(connected_in_ports)
+        assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
+        equation = node.equation
+
+        # parse the equation and extract input and output subscripts
+        input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
+
+        # check that each operand has the corresponding input subscript
+        assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
+                                                    "must match the number of input subscripts " \
+                                                    "in `equation`".format(node_name)
+
+        # check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
+        label_to_shape = {}
+        for input_ind in range(num_inputs):
+            input_shape = node.in_port(input_ind).data.get_shape()
+            input_subscript = input_subscripts[input_ind]
+            labels = Einsum.extract_subscript_labels(node_name, input_subscript)
+            num_dims = len(input_shape)
+            num_labels = len(labels)
+            num_broadcasted_dims = num_dims - num_labels + 1
+            dim_ind = 0
+            label_ind = 0
+            while label_ind < num_labels and dim_ind < num_dims:
+                label = labels[label_ind]
+                if label == "...":
+                    sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
+                    if label in label_to_shape.keys():
+                        common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
+                        assert common_shape is not None, "The dimensions covered by ellipses must be broadcastable " \
+                                                         "for Einsum node {}".format(node_name)
+                        label_to_shape[label] = common_shape
+                    else:
+                        label_to_shape[label] = sub_shape
+                    dim_ind += num_broadcasted_dims
+                else:
+                    dim_size = input_shape[dim_ind]
+                    sub_shape = int64_array([dim_size])
+                    assert label not in label_to_shape.keys() or np.array_equal(label_to_shape[label], sub_shape), \
+                        "Sizes of dimensions with the same label of Einsum node {} " \
+                        "must be compatible".format(node_name)
+                    label_to_shape[label] = sub_shape
+                    dim_ind += 1
+                label_ind += 1
+
+        # generate the output shape based on the output subscript
+        output_shape = int64_array([])
+        labels = Einsum.extract_subscript_labels(node_name, output_subscript)
+        for label in labels:
+            assert label in label_to_shape.keys(), "The label in the output subscript must appear" \
+                                                   " in input subscripts in equation {} " \
+                                                   "of Einsum node {}".format(equation, node_name)
+            output_shape = np.concatenate((output_shape, label_to_shape[label]))
+
+        node.out_port(0).data.set_shape(output_shape)
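+
+# Shape inference sketch: for the equation 'ab,bc->ac' with input shapes [2, 3] and [3, 4],
+# infer() collects label_to_shape = {'a': [2], 'b': [3], 'c': [4]} and sets the output
+# shape [2, 4], matching the matrix multiplication case from the unit tests.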
diff --git a/model-optimizer/mo/utils/ir_reader/extenders/einsum_extender.py b/model-optimizer/mo/utils/ir_reader/extenders/einsum_extender.py
new file mode 100644
index 00000000000..1cb3e35bc52
--- /dev/null
+++ b/model-optimizer/mo/utils/ir_reader/extenders/einsum_extender.py
@@ -0,0 +1,17 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from mo.utils.graph import Node
+from mo.utils.ir_reader.extender import Extender
+
+
+class Einsum_extender(Extender):
+    op = 'Einsum'
+
+    @staticmethod
+    def extend(op: Node):
+        einsum_name = op.soft_get('name', op.id)
+        if isinstance(op['equation'], list):
+            op['equation'] = ','.join(op['equation'])
+        elif not isinstance(op['equation'], str):
+            assert False, "Equation of Einsum node {} has incorrect format.".format(einsum_name)
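+
+# Note: the equation may come back from the IR as a list, presumably because the IR reader
+# splits attribute values on commas (e.g. 'ab,bc->ac' is read as ['ab', 'bc->ac']),
+# hence the re-join with ',' above.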
diff --git a/model-optimizer/unit_tests/extensions/back/LayoutChangeForEinsum_test.py b/model-optimizer/unit_tests/extensions/back/LayoutChangeForEinsum_test.py
new file mode 100644
index 00000000000..45e0f2badab
--- /dev/null
+++ b/model-optimizer/unit_tests/extensions/back/LayoutChangeForEinsum_test.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+
+from extensions.back.LayoutChangeForEinsum import LayoutChangeForEinsum
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect
+
+nodes_attributes = {
+    # Parameter layers
+    **regular_op_with_shaped_data('placeholder_1', None, {'type': 'Parameter', 'op': 'Parameter'}),
+    **regular_op_with_shaped_data('placeholder_2', None, {'type': 'Parameter', 'op': 'Parameter'}),
+    **regular_op_with_shaped_data('placeholder_3', None, {'type': 'Parameter', 'op': 'Parameter'}),
+
+    # Einsum layer
+    **regular_op_with_shaped_data('einsum', None, {'type': 'Einsum', 'op': 'Einsum'}),
+
+    # Result layer
+    **result(),
+
+    # Transpose layers
+    **regular_op_with_shaped_data('transpose_1', None,
+                                  {'type': 'Transpose', 'op': 'Transpose', 'need_shape_inference': True}),
+    **regular_op_with_shaped_data('transpose_3', None,
+                                  {'type': 'Transpose', 'op': 'Transpose', 'need_shape_inference': True}),
+
+    # Const layers
+    **valued_const_with_data('axis_1_const', int64_array([0, 2, 3, 1])),
+    **valued_const_with_data('axis_3_const', int64_array([0, 4, 1, 2, 3])),
+}
+
+
+class LayoutChangeForEinsumTests(unittest.TestCase):
+    def test_layout_change_einsum(self):
+        graph = build_graph(nodes_attributes,
+                            [*connect('placeholder_1', '0:einsum'),
+                             *connect('placeholder_2', '1:einsum'),
+                             *connect('placeholder_3', '2:einsum'),
+                             *connect('einsum', 'output')],
+                            {  # this input stays as is since it is of a rank equal to 3
+                                'placeholder_1_d': {'shape': np.array([2, 3, 5])},
+                                # [3, 5, 7, 8] - NHWC, [3, 8, 5, 7] - NCHW
+                                # this input does not require an additional transpose
+                                # since the corresponding subscript can be adjusted
+                                'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
+                                # [3, 5, 10, 12] - NHWC, [3, 12, 5, 10] - NCHW
+                                # the third input must be transposed to the NHWC layout
+                                # since the ellipsis covers multiple dimensions in the end;
+                                # the corresponding subscript is not changed
+                                'placeholder_3_d': {'shape': np.array([3, 12, 5, 10])},
+                                # the equation is still for the NHWC layout
+                                'einsum': {'equation': "abc,bcde,bc...->ade..."},
+                                # [2, 7, 8, 10, 12] - NHWC, [2, 12, 7, 8, 10] - NCHW
+                                # the output is in the NCHW layout but its shape will be re-inferred since
+                                # the output stays in the NHWC layout due to the ellipsis in the end
+                                # and an additional transpose to NCHW will be inserted
+                                'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])},
+                            }, nodes_with_edges_only=True)
+        graph.graph['fw'] = 'tf'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [*connect('placeholder_3', '0:transpose_1'),
+                                 *connect('axis_1_const', '1:transpose_1'),
+                                 *connect('placeholder_1', '0:einsum'),
+                                 *connect('placeholder_2', '1:einsum'),
+                                 *connect('transpose_1', '2:einsum'),
+                                 *connect('einsum', '0:transpose_3'),
+                                 *connect('axis_3_const', '1:transpose_3'),
+                                 *connect('transpose_3', 'output')],
+                                {'placeholder_1_d': {'shape': np.array([2, 3, 5])},
+                                 'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
+                                 'einsum': {'equation': "abc,becd,bc...->ade..."},
+                                 'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])}
+                                 })
+
+        LayoutChangeForEinsum().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/unit_tests/extensions/ops/einsum_test.py b/model-optimizer/unit_tests/extensions/ops/einsum_test.py
new file mode 100644
index 00000000000..9407fe63d5f
--- /dev/null
+++ b/model-optimizer/unit_tests/extensions/ops/einsum_test.py
@@ -0,0 +1,90 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+from generator import generator, generate
+
+from extensions.ops.einsum import Einsum
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Graph
+from mo.graph.graph import Node
+from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, result, connect
+
+
+def create_einsum_graph(input_shapes: list, equation: str) -> Graph:
+    num_inputs = len(input_shapes)
+    assert num_inputs > 0, "Einsum node must have at least one input"
+    nodes = {}
+    edges = []
+    for input_ind in range(num_inputs):
+        input_name = 'input' + str(input_ind)
+        parameter_op = regular_op_with_shaped_data(input_name, input_shapes[input_ind],
+                                                   {'op': 'Parameter', 'type': 'Parameter'})
+        nodes.update(parameter_op)
+        edges += connect(input_name, str(input_ind) + ":einsum_node")
+    einsum_op = regular_op_with_shaped_data('einsum_node', None,
+                                            {'op': 'Einsum', 'type': 'Einsum', 'equation': equation})
+    nodes.update(einsum_op)
+    result_op = result('output')
+    nodes.update(result_op)
+    edges += connect('einsum_node', 'output')
+
+    graph = build_graph(nodes, edges, nodes_with_edges_only=True)
+    return graph
+
+
+@generator
+class TestEinsum(unittest.TestCase):
+    @generate(*[
+        # dot product
+        ([int64_array([10]), int64_array([10])], "i,i->", int64_array([])),
+        # matrix multiplication
+        ([int64_array([2, 3]), int64_array([3, 4])], "ab,bc->ac", int64_array([2, 4])),
+        # trace per batch
+        ([int64_array([2, 3, 3])], "kii->k", int64_array([2])),
+        # diagonal extraction
+        ([int64_array([6, 5, 5])], "kii->ki", int64_array([6, 5])),
+        # transpose
+        ([int64_array([1, 2, 3])], "ijk->kij", int64_array([3, 1, 2])),
+        # multiple matrix multiplication
+        ([int64_array([2, 5]), int64_array([5, 3, 6]), int64_array([5, 3])], "ab,bcd,bc->ca", int64_array([3, 2])),
+        # ellipsis for one operand
+        ([int64_array([5, 3, 4])], "a...->...", int64_array([3, 4])),
+        # ellipsis for multiple operands
+        ([int64_array([3, 5]), int64_array([1])], "a...,...->a...", int64_array([3, 5])),
+        # ellipsis with broadcasting
+        ([int64_array([9, 1, 4, 3]), int64_array([3, 11, 7, 1])], "a...b,b...->a...", int64_array([9, 11, 7, 4])),
+        # mixed-case letters in the equation
+        ([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])),
+        # mixed-case letters and an equation in the implicit mode
+        ([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])),
+    ])
+    def test_einsum(self, input_shapes, equation, ref_output_shape):
+        graph = create_einsum_graph(input_shapes, equation)
+        einsum_node = Node(graph, 'einsum_node')
+        Einsum.infer(einsum_node)
+
+        # get the result
+        res_output_shape = graph.node['einsum_node_d']['shape']
+
+        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
+                        'shape does not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
+
+    @generate(*[
+        # mismatch of the number of input subscripts and inputs
+        ([int64_array([3, 11]), int64_array([11, 4])], "ab,bc,cd->ac", None),
+        # invalid labels
+        ([int64_array([3, 11]), int64_array([11, 4])], "a$,Bc->ac", None),
+        # incompatible shapes
+        ([int64_array([3, 11]), int64_array([12, 4])], "ab,bc->ac", None),
+        # not broadcastable shapes
+        ([int64_array([11, 1, 4, 3]), int64_array([3, 11, 7, 5])], "a...b,b...->a...", None),
+        # missed ellipsis in the output subscript
+        ([int64_array([11, 1, 4, 3]), int64_array([3, 11, 7, 4])], "a...b,b...->a", None),
+    ])
+    def test_invalid_cases(self, input_shapes, equation, ref_output_shape):
+        graph = create_einsum_graph(input_shapes, equation)
+        einsum_node = Node(graph, 'einsum_node')
+        self.assertRaises(AssertionError, Einsum.infer, einsum_node)
diff --git a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py
index 74217e26e53..2482a5b022b 100644
--- a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py
+++ b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py
@@ -14,7 +14,7 @@ nodes_attributes = {'data': {'kind': 'op'},
                     'data_data': {'shape': None, 'value': None, 'kind': 'data'},
                     'indices': {'kind': 'op'},
                     'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
-                    'gathernd_node': {'op': 'ScatterNDUpdate', 'kind': 'op', 'batch_dims': 0},
+                    'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0},
                     'output': {'shape': None, 'value': None, 'kind': 'data'}}
 
 # graph 1
@@ -118,7 +118,7 @@ inputs_inv2 = {'data_data': {'shape': int64_array([10, 40, 20]), 'value': None},
                'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
 inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value': None},
                'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
 
-class TestScatterNDUpdate(unittest.TestCase):
+class TestGatherNDUpdate(unittest.TestCase):
     def setUp(self):
         nodes_attributes['gathernd_node']['batch_dims'] = 0