diff --git a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
index 782616375b1..e922e6917f9 100644
--- a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
+++ b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
@@ -53,6 +53,7 @@
 | Symbol Name in MXNet\*| Limitations|
 | :----------| :----------|
 | _Plus | |
+| _contrib_arange_like | |
 | _contrib_box_nms | |
 | _contrib_DeformableConvolution | |
 | _contrib_DeformablePSROIPooling | |
diff --git a/tools/mo/automation/package_BOM.txt b/tools/mo/automation/package_BOM.txt
index 904d0ad61e6..c962ed4bf6b 100644
--- a/tools/mo/automation/package_BOM.txt
+++ b/tools/mo/automation/package_BOM.txt
@@ -268,6 +268,8 @@ openvino/tools/mo/front/mxnet/activation.py
 openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py
 openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
 openvino/tools/mo/front/mxnet/arange_ext.py
+openvino/tools/mo/front/mxnet/arange_like_ext.py
+openvino/tools/mo/front/mxnet/arange_like_replacer.py
 openvino/tools/mo/front/mxnet/arange_replacer.py
 openvino/tools/mo/front/mxnet/batch_dot_ext.py
 openvino/tools/mo/front/mxnet/block_grad_ext.py
@@ -847,6 +849,7 @@ openvino/tools/mo/ops/__init__.py
 openvino/tools/mo/ops/activation.py
 openvino/tools/mo/ops/activation_ops.py
 openvino/tools/mo/ops/adaptive_avg_pooling.py
+openvino/tools/mo/ops/arange_like.py
 openvino/tools/mo/ops/argmax.py
 openvino/tools/mo/ops/argmin.py
 openvino/tools/mo/ops/assert_op.py
diff --git a/tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py b/tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py
new file mode 100644
index 00000000000..89ad2507d60
--- /dev/null
+++ b/tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py
@@ -0,0 +1,25 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+
+from openvino.tools.mo.front.extractor import FrontExtractorOp
+from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
+from openvino.tools.mo.graph.graph import Node
+from openvino.tools.mo.ops.arange_like import ArangeLikeOp
+
+
+class ArangeLikeExt(FrontExtractorOp):
+    op = '_contrib_arange_like'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node: Node):
+        attrs = get_mxnet_layer_attrs(node.symbol_dict)
+        ArangeLikeOp.update_node_stat(node, {
+            'start': attrs.float('start', 0),
+            'repeat': attrs.int('repeat', 1),
+            'step': attrs.float('step', 1),
+            'axis': attrs.int('axis', None),
+        })
+        return cls.enabled
diff --git a/tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py b/tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py
new file mode 100644
index 00000000000..16272b406bf
--- /dev/null
+++ b/tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py
@@ -0,0 +1,148 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+
+from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
+from openvino.tools.mo.front.common.replacement import FrontReplacementOp
+from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
+from openvino.tools.mo.graph.graph import Graph, rename_nodes
+from openvino.tools.mo.ops.Cast import Cast
+from openvino.tools.mo.ops.ReduceOps import ReduceProd
+from openvino.tools.mo.ops.elementwise import Add, Div, Mul
+from openvino.tools.mo.ops.gather import Gather
+from openvino.tools.mo.ops.range import Range
+from openvino.tools.mo.ops.reshape import Reshape
+from openvino.tools.mo.ops.shape import Shape
+from openvino.tools.mo.ops.slice import Slice
+from openvino.tools.mo.ops.squeeze import Squeeze
+from openvino.tools.mo.ops.tile import Tile
+from openvino.tools.mo.utils.error import Error
+
+
+class ArangeLikeReplacer(FrontReplacementOp):
+    op = 'arange_like'
+    enabled = True
+
+    def replace_sub_graph(self, graph: Graph, match: dict):
+        node = match['op']
+        name = node.soft_get('name', node.id)
+        axis = node.axis
+        input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
+        range_node = create_op_with_const_inputs(graph, Range, {0: mo_array(node.start),
+                                                                2: mo_array(node.step)}, {'name': name + '/Range'})
+        node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))
+
+        if axis is not None:
+            '''
+            Replace the arange_like op with the sub-graph:
+            Shape - Gather - Range
+            '''
+            gather_node = create_op_with_const_inputs(graph, Gather, {1: int64_array([axis]),
+                                                                      2: int64_array(0)},
+                                                      {'name': name + '/Gather'})
+            input_shape_node.out_port(0).connect(gather_node.in_port(0))
+            gather_node.out_port(0).connect(range_node.in_port(1))
+            node.out_port(0).get_connection().set_source(range_node.out_port(0))
+            rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
+        else:
+            r'''
+            Replace the arange_like op with the sub-graph:
+                    |
+                 ShapeOf ----------- |
+                    |                |
+               ReduceProd            |
+                    |                |
+                  Range              |
+                    |                |
+                 Reshape ----------- |
+                    |
+            '''
+
+            flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
+                                                               {'name': input_shape_node.name + '/ReduceProd',
+                                                                'keep_dims': True})
+            reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()
+
+            input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
+            flattened_shape_node.out_port(0).connect(range_node.in_port(1))
+            range_node.out_port(0).connect(reshape_backward_node.in_port(0))
+            input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
+            node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
+            rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])
+
+        if node.repeat != 1:
+            r"""
+            First, we generate the correct stop value for Range as new_stop_value = stop_value // repeat + 1.
+            Then each value of the interval is repeated using Tile. Since this can produce a longer interval,
+            we reduce it back with Slice.
+
+            The sub-graph after the Range node will look like:
+
+            Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
+
+            """
+
+            if node.repeat < 1:
+                raise Error("Unexpected value {} of the attribute 'repeat' for the node {}".
+                            format(node.repeat, name))
+
+            div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
+                                                   {'name': name + '/Divide'})
+            add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
+                                                   {'name': div_node.name + '/Add'})
+            cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()
+
+            cast_node.out_port(0).connect(div_node.in_port(0))
+            div_node.out_port(0).connect(add_node.in_port(0))
+            range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
+            add_node.out_port(0).connect(range_node.in_port(1))
+
+            tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
+                                                               {'name': range_node.name + '/ForwardReshape'})
+            tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
+                                               {'name': tile_forward_reshape.name + '/Tile'})
+            tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
+                                                                {'name': tile.name + '/BackwardReshape'})
+            slice_node = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
+                                                                    4: int64_array([1])},
+                                                     {'name': tile_backward_reshape.name + '/Slice'})
+
+            tile_forward_reshape.out_port(0).connect(tile.in_port(0))
+            tile.out_port(0).connect(tile_backward_reshape.in_port(0))
+            tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
+            slice_node.in_port(2).connect(div_node.in_port(0).get_source())
+
+            range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
+            range_node.out_port(0).connect(tile_forward_reshape.in_port(0))
+
+            if axis is not None:
+                rename_nodes([(range_node, name + '/Range'), (slice_node, name)])
+
+        # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
+        # we have to correct the stop value for the Range node if step != 1 or start != 0
+        if node.step != 1:
+            # If the step attribute is not an integer, we generate an interval with a larger size and then reduce it
+            # using Slice
+            true_elements_count_port = range_node.in_port(1).get_source()
+            mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
+            stop_value = create_op_with_const_inputs(graph, Mul, port_value_dict={1: mo_array(np.ceil(mul_value))},
+                                                     op_attrs={'name': range_node.name + '/Stop'})
+            range_node.in_port(1).get_connection().insert_node(stop_value)
+
+            slice_range_values = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
+                                                                            4: int64_array([1])},
+                                                             {'name': range_node.name + '/Slice'})
+            slice_range_values.in_port(2).connect(true_elements_count_port)
+            range_node.out_port(0).get_connection().insert_node(slice_range_values)
+
+            if axis is not None and node.repeat == 1:
+                rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])
+
+        if node.start != 0:
+            correct_stop_value = create_op_with_const_inputs(graph, Add, port_value_dict={1: mo_array(node.start)},
+                                                             op_attrs={'name': range_node.name + '/Correct_Stop'})
+            range_node.in_port(1).get_connection().insert_node(correct_stop_value)
+
+        # Range node supports only scalar inputs
+        squeeze_node = create_op_with_const_inputs(graph, Squeeze, port_value_dict={1: int64_array(0)},
+                                                   op_attrs={"name": range_node.name + '/Stop/Squeeze'})
+        range_node.in_port(1).get_connection().insert_node(squeeze_node)
diff --git a/tools/mo/openvino/tools/mo/front/mxnet/div_sqrt_dim.py b/tools/mo/openvino/tools/mo/front/mxnet/div_sqrt_dim.py
index 782c38eb476..db867b1c4b2 100644
--- a/tools/mo/openvino/tools/mo/front/mxnet/div_sqrt_dim.py
+++ b/tools/mo/openvino/tools/mo/front/mxnet/div_sqrt_dim.py
@@ -39,15 +39,12 @@ class DivSqrtDim(FrontReplacementOp):
         # Due to specification, Power must have inputs with the same data type.
         convert_pow_input = Cast(graph, dict(dst_type=np.float32,
                                              name=shape_values_node.name + '/ConvertToFP32')).create_node()
-        convert_pow_output = ConvertLike(graph, dict(name=pow_node.name + 'ConvertLike')).create_node()
         div_node = Div(graph, dict(name="Div")).create_node()
 
         shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
         convert_pow_input.out_port(0).connect(pow_node.in_port(0))
         div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
-        pow_node.out_port(0).connect(convert_pow_output.in_port(0))
-        convert_pow_output.in_port(1).connect(data_out_port)
-        div_node.in_port(1).connect(convert_pow_output.out_port(0))
+        div_node.in_port(1).connect(pow_node.out_port(0))
         div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))
 
         rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])
diff --git a/tools/mo/openvino/tools/mo/ops/arange_like.py b/tools/mo/openvino/tools/mo/ops/arange_like.py
new file mode 100644
index 00000000000..fd78a6da8ff
--- /dev/null
+++ b/tools/mo/openvino/tools/mo/ops/arange_like.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+from openvino.tools.mo.graph.graph import Graph
+from openvino.tools.mo.ops.op import Op
+
+
+class ArangeLikeOp(Op):
+    """
+    MXNet operation which returns a sequence of numbers. If the axis attribute is None, the output has the
+    same shape as the input. Otherwise, the output is a 1D array with the size of the specified axis.
+
+    Attributes:
+        start - Start of interval
+        step - Spacing between values
+        repeat - The number of times each element is repeated. Currently, only the default value (= 1) is supported
+        axis - Arange elements according to the size of a certain axis of the input array. Default value is None
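+
+    Example (illustrative only, assuming a hypothetical 2x3 input tensor with the default start=0, step=1, repeat=1):
+        arange_like(x, axis=1)    -> [0, 1, 2]
+        arange_like(x, axis=None) -> [[0, 1, 2], [3, 4, 5]]  (same shape as the input)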
+
+    """
+    op = 'arange_like'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        mandatory_props = {
+            'type': None,
+            'op': self.op,
+            'infer': None,
+            'in_ports_count': 1,
+            'out_ports_count': 1,
+        }
+        super().__init__(graph, mandatory_props, attrs)
\ No newline at end of file
diff --git a/tools/mo/openvino/tools/mo/ops/gather.py b/tools/mo/openvino/tools/mo/ops/gather.py
index 7e69de7de25..c89d7675886 100644
--- a/tools/mo/openvino/tools/mo/ops/gather.py
+++ b/tools/mo/openvino/tools/mo/ops/gather.py
@@ -90,6 +90,7 @@ class Gather(Op):
         data_value = node.in_port(0).data.get_value()
         indices_value = node.in_port(1).data.get_value()
         if data_value is not None and indices_value is not None and is_fully_defined(indices_value):
+            indices_value = int64_array(indices_value)
             if batch_dims == 0:
                 node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis))
             else:
diff --git a/tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py b/tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py
new file mode 100644
index 00000000000..c93d254824f
--- /dev/null
+++ b/tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py
@@ -0,0 +1,240 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import unittest
+from openvino.tools.mo.front.common.partial_infer.utils import int64_array
+from openvino.tools.mo.front.mxnet.arange_like_replacer import ArangeLikeReplacer
+from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
+from unit_tests.utils.graph import build_graph, shaped_parameter, regular_op_with_empty_data, result, connect, \
+    shaped_const_with_data, connect_data
+
+
+class ArangeLikeReplacerTest(unittest.TestCase):
+    def test_axis_not_none_start_0(self):
+        graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
+                                                             'start': 0, 'step': 1}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'arange_like'),
+                *connect('arange_like', 'result')
+            ]
+        )
+        ref_graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
+                **shaped_const_with_data('gather_axis', None),
+                **shaped_const_with_data('gather_indices', None),
+                **regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
+                **shaped_const_with_data('range_start', None),
+                **shaped_const_with_data('range_step', None),
+                **shaped_const_with_data('squeeze_const', None),
+                **regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
+                **regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'shape_of'),
+                *connect('shape_of', '0:gather'),
+                *connect('gather_axis', '1:gather'),
+                *connect('gather_indices', '2:gather'),
+                *connect('range_start', '0:range'),
+                *connect('gather', '0:squeeze'),
+                *connect('squeeze_const', '1:squeeze'),
+                *connect('squeeze', '1:range'),
+                *connect('range_step', '2:range'),
+                *connect('range', 'result')
+            ],
+            update_attributes={
+                'gather_axis': {'value': 3},
+                'gather_indices': {'value': 0},
+                'range_start': {'value': 0},
+                'range_step': {'value': 1}
+            }
+        )
+        ArangeLikeReplacer().find_and_replace_pattern(graph)
+        flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_axis_not_none_start_1_step_2(self):
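+        # start=1 with step=2 along axis 3: the reference graph below expects the replacer to scale the Range
+        # stop value with Mul/Add and to trim the generated sequence back to the axis size with a Slice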
+        graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
+                                                             'start': 1, 'step': 2}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'arange_like'),
+                *connect('arange_like', 'result')
+            ]
+        )
+        ref_graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
+                **shaped_const_with_data('gather_axis', None),
+                **shaped_const_with_data('gather_indices', None),
+                **regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
+                **regular_op_with_empty_data('mul', {'op': 'Mul', 'type': 'Multiply'}),
+                **shaped_const_with_data('mul_const', None),
+                **shaped_const_with_data('range_start', None),
+                **shaped_const_with_data('range_step', None),
+                **shaped_const_with_data('add_const', None),
+                **regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
+                **shaped_const_with_data('squeeze_const', None),
+                **regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
+                **regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
+                **regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
+                **shaped_const_with_data('slice_start', None),
+                **shaped_const_with_data('slice_axes', None),
+                **shaped_const_with_data('slice_step', None),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'shape_of'),
+                *connect('shape_of', '0:gather'),
+                *connect('gather_axis', '1:gather'),
+                *connect('gather_indices', '2:gather'),
+                *connect('range_start', '0:range'),
+                *connect('gather', '0:mul'),
+                *connect('mul_const', '1:mul'),
+                *connect('mul', '0:add'),
+                *connect('add_const', '1:add'),
+                *connect('squeeze_const', '1:squeeze'),
+                *connect('add', '0:squeeze'),
+                *connect('squeeze', '1:range'),
+                *connect('range_step', '2:range'),
+                *connect('range', '0:slice'),
+                *connect('slice_start', '1:slice'),
+                *connect_data('gather', '2:slice'),
+                *connect('slice_axes', '3:slice'),
+                *connect('slice_step', '4:slice'),
+                *connect('slice', 'result')
+            ],
+            update_attributes={
+                'gather_axis': {'value': 3},
+                'gather_indices': {'value': 0},
+                'range_start': {'value': 1},
+                'range_step': {'value': 2},
+                'add_const': {'value': 1},
+                'mul_const': {'value': 2},
+                'slice_start': {'value': int64_array([0])},
+                'slice_axes': {'value': int64_array([0])},
+                'slice_step': {'value': int64_array([1])},
+            }
+        )
+        ArangeLikeReplacer().find_and_replace_pattern(graph)
+        flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_axis_none_start_0(self):
+        graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
+                                                             'repeat': 1, 'start': 0, 'step': 1}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'arange_like'),
+                *connect('arange_like', 'result')
+            ]
+        )
+        ref_graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
+                **regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
+                **shaped_const_with_data('reduce_prod_const', None),
+                **shaped_const_with_data('squeeze_const', None),
+                **regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
+                **shaped_const_with_data('range_start', None),
+                **shaped_const_with_data('range_step', None),
+                **regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
+                **regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'shape_of'),
+                *connect('shape_of', '0:reduce_prod'),
+                *connect('reduce_prod_const', '1:reduce_prod'),
+                *connect('squeeze_const', '1:squeeze'),
+                *connect('reduce_prod', '0:squeeze'),
+                *connect('range_start', '0:range'),
+                *connect('range_step', '2:range'),
+                *connect('squeeze', '1:range'),
+                *connect('range', '0:reshape_backward'),
+                *connect_data('shape_of', '1:reshape_backward'),
+                *connect('reshape_backward', 'result')
+            ],
+            update_attributes={
+                'range_start': {'value': 0},
+                'range_step': {'value': 1},
+                'reduce_prod_const': {'value': int64_array([0])}
+            }
+        )
+
+        ArangeLikeReplacer().find_and_replace_pattern(graph)
+        flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_axis_none_start_1(self):
+        graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
+                                                             'repeat': 1, 'start': 1, 'step': 1}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'arange_like'),
+                *connect('arange_like', 'result')
+            ]
+        )
+        ref_graph = build_graph(
+            nodes_attrs={
+                **shaped_parameter('input', int64_array([1, 3, 5, 5])),
+                **regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
+                **regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
+                **shaped_const_with_data('reduce_prod_const', None),
+                **shaped_const_with_data('squeeze_const', None),
+                **regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
+                **shaped_const_with_data('add_const', None),
+                **regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
+                **shaped_const_with_data('range_start', None),
+                **shaped_const_with_data('range_step', None),
+                **regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
+                **regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'shape_of'),
+                *connect('shape_of', '0:reduce_prod'),
+                *connect('reduce_prod_const', '1:reduce_prod'),
+                *connect('squeeze_const', '1:squeeze'),
+                *connect('add_const', '1:add'),
+                *connect('reduce_prod', '0:add'),
+                *connect('add', '0:squeeze'),
+                *connect('range_start', '0:range'),
+                *connect('range_step', '2:range'),
+                *connect('squeeze', '1:range'),
+                *connect('range', '0:reshape_backward'),
+                *connect_data('shape_of', '1:reshape_backward'),
+                *connect('reshape_backward', 'result')
+            ],
+            update_attributes={
+                'range_start': {'value': 1},
+                'range_step': {'value': 1},
+                'add_const': {'value': 1},
+                'reduce_prod_const': {'value': int64_array([0])}
+            }
+        )
+        ArangeLikeReplacer().find_and_replace_pattern(graph)
+        flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
diff --git a/tools/mo/unit_tests/mo/front/mxnet/div_sqrt_dim_test.py b/tools/mo/unit_tests/mo/front/mxnet/div_sqrt_dim_test.py
index 77ea6a8610f..67fed218bc7 100644
--- a/tools/mo/unit_tests/mo/front/mxnet/div_sqrt_dim_test.py
+++ b/tools/mo/unit_tests/mo/front/mxnet/div_sqrt_dim_test.py
@@ -36,7 +36,6 @@ class DivSqrtDimTest(unittest.TestCase):
                 **regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
                 **regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}),
                 **regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}),
-                **regular_op_with_empty_data('z_convert_like', {'op': 'ConvertLike', 'type': 'ConvertLike'}),
                 **regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}),
                 **result('result')
             },
@@ -48,9 +47,7 @@ class DivSqrtDimTest(unittest.TestCase):
                 *connect('gather_indices', '2:gather'),
                 *connect('gather', 'cast'),
                 *connect('cast', 'power'),
-                *connect('power', '0:z_convert_like'),
-                *connect_front('input_d', '1:z_convert_like'),
-                *connect('z_convert_like', '1:div'),
+                *connect('power', '1:div'),
                 *connect('div', 'result')
             ],
         )