From 8c0e52f7dca4ee77221ed41c74521b8e10f3d4f1 Mon Sep 17 00:00:00 2001 From: yekruglov Date: Mon, 20 Dec 2021 13:22:09 +0300 Subject: [PATCH] added replacer for batch_dot --- tools/mo/automation/package_BOM.txt | 2 + .../tools/mo/front/MatMul_normalizer.py | 4 +- .../tools/mo/front/mxnet/batch_dot_ext.py | 4 +- .../mo/front/mxnet/batch_dot_replacer.py | 27 ++++++++++++ tools/mo/openvino/tools/mo/ops/batch_dot.py | 27 ++++++++++++ .../mo/openvino/tools/mo/ops/div_sqrt_dim.py | 2 +- .../mo/front/mxnet/batch_dot_replacer_test.py | 44 +++++++++++++++++++ 7 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py create mode 100644 tools/mo/openvino/tools/mo/ops/batch_dot.py create mode 100644 tools/mo/unit_tests/mo/front/mxnet/batch_dot_replacer_test.py diff --git a/tools/mo/automation/package_BOM.txt b/tools/mo/automation/package_BOM.txt index 4774b99b890..15f2c08bc77 100644 --- a/tools/mo/automation/package_BOM.txt +++ b/tools/mo/automation/package_BOM.txt @@ -273,6 +273,7 @@ openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py openvino/tools/mo/front/mxnet/arange_ext.py openvino/tools/mo/front/mxnet/arange_replacer.py openvino/tools/mo/front/mxnet/batch_dot_ext.py +openvino/tools/mo/front/mxnet/batch_dot_replacer.py openvino/tools/mo/front/mxnet/block_grad_ext.py openvino/tools/mo/front/mxnet/box_nms_ext.py openvino/tools/mo/front/mxnet/cast_ext.py @@ -850,6 +851,7 @@ openvino/tools/mo/ops/assert_op.py openvino/tools/mo/ops/assign.py openvino/tools/mo/ops/aten.py openvino/tools/mo/ops/axpy.py +openvino/tools/mo/ops/batch_dot.py openvino/tools/mo/ops/BatchNormInference.py openvino/tools/mo/ops/binarization.py openvino/tools/mo/ops/BlockLSTM.py diff --git a/tools/mo/openvino/tools/mo/front/MatMul_normalizer.py b/tools/mo/openvino/tools/mo/front/MatMul_normalizer.py index b4260423f99..bd42c7eafca 100644 --- a/tools/mo/openvino/tools/mo/front/MatMul_normalizer.py +++ b/tools/mo/openvino/tools/mo/front/MatMul_normalizer.py @@ -56,8 +56,8 @@ class FullyConnectedDecomposer(FrontReplacementSubgraph): node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose, new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0])) - # input normalization for 4D Caffe - if graph.graph['fw'] in ['caffe']: + # input normalization for 4D Caffe and MxNet FullyConnected + if graph.graph['fw'] in ['caffe', 'mxnet']: node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape, new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1])) diff --git a/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_ext.py b/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_ext.py index f36da047935..0b975fba0ad 100644 --- a/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_ext.py +++ b/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_ext.py @@ -3,7 +3,7 @@ from openvino.tools.mo.front.extractor import FrontExtractorOp from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs from openvino.tools.mo.graph.graph import Node -from openvino.tools.mo.ops.MatMul import MatMul +from openvino.tools.mo.ops.batch_dot import MXNetBatchDot class BatchDotExt(FrontExtractorOp): @@ -23,7 +23,7 @@ class BatchDotExt(FrontExtractorOp): transpose_a = attrs.bool('transpose_a', False) transpose_b = attrs.bool('transpose_b', False) - MatMul.update_node_stat(node, { + MXNetBatchDot.update_node_stat(node, { 'transpose_a': transpose_a, 'transpose_b': transpose_b }) diff --git a/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py b/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py new file mode 100644 index 00000000000..b56c3b6bd1c --- /dev/null +++ b/tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py @@ -0,0 +1,27 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from openvino.tools.mo.front.MatMul_normalizer import FullyConnectedDecomposer +from openvino.tools.mo.front.common.replacement import FrontReplacementPattern +from openvino.tools.mo.graph.graph import Graph, rename_nodes +from openvino.tools.mo.ops.MatMul import MatMul + + +class BatchDotReplacer(FrontReplacementPattern): + """ + Replaces MXNet batch_dot with MatMul. Should run after FullyConnectedDecomposer, because batch_dot does not need + in MatMul normalization. + """ + enabled = True + + def run_after(self): + return [FullyConnectedDecomposer] + + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(op='batch_dot'): + node_name = node.soft_get('name', node.id) + matmul_node = MatMul(graph, dict(name='MatMul', transpose_a=node.transpose_a, + transpose_b=node.transpose_b)).create_node() + node.in_port(0).get_connection().set_destination(matmul_node.in_port(0)) + node.in_port(1).get_connection().set_destination(matmul_node.in_port(1)) + node.out_port(0).get_connection().set_source(matmul_node.out_port(0)) + rename_nodes([(matmul_node, node_name)]) diff --git a/tools/mo/openvino/tools/mo/ops/batch_dot.py b/tools/mo/openvino/tools/mo/ops/batch_dot.py new file mode 100644 index 00000000000..b690d1e5787 --- /dev/null +++ b/tools/mo/openvino/tools/mo/ops/batch_dot.py @@ -0,0 +1,27 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from openvino.tools.mo.graph.graph import Graph +from openvino.tools.mo.ops.op import Op + + +class MXNetBatchDot(Op): + """ + MXNet operation which computes matrix multiplication of x and y similar to TF or ONNX MatMul operation. + + Attributes: + transpose_a - if true then transpose the first input before multiplication + transpose_b - if true then transpose the second input before multiplication + """ + op = 'batch_dot' + + def __init__(self, graph: Graph, attrs: dict): + mandatory_props = { + 'type': None, + 'op': self.op, + 'infer': None, + 'in_ports_count': 2, + 'out_ports_count': 1, + 'transpose_a': False, + 'transpose_b': False + } + super().__init__(graph, mandatory_props, attrs) diff --git a/tools/mo/openvino/tools/mo/ops/div_sqrt_dim.py b/tools/mo/openvino/tools/mo/ops/div_sqrt_dim.py index 2f543fa09eb..e475c0d1193 100644 --- a/tools/mo/openvino/tools/mo/ops/div_sqrt_dim.py +++ b/tools/mo/openvino/tools/mo/ops/div_sqrt_dim.py @@ -14,7 +14,7 @@ class DivSqrtDimOp(Op): def __init__(self, graph: Graph, attrs: dict): mandatory_props = { 'type': None, - 'op': __class__.op, + 'op': self.op, 'infer': None, 'in_ports_count': 1, 'out_ports_count': 1, diff --git a/tools/mo/unit_tests/mo/front/mxnet/batch_dot_replacer_test.py b/tools/mo/unit_tests/mo/front/mxnet/batch_dot_replacer_test.py new file mode 100644 index 00000000000..39b602cda9e --- /dev/null +++ b/tools/mo/unit_tests/mo/front/mxnet/batch_dot_replacer_test.py @@ -0,0 +1,44 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from openvino.tools.mo.front.mxnet.batch_dot_replacer import BatchDotReplacer +from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs +from unit_tests.utils.graph import build_graph, regular_op, result, connect_front + + +class BatchDotReplacerTest(unittest.TestCase): + + def test_1(self): + graph = build_graph( + nodes_attrs={ + **regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}), + **regular_op('const_input', {'op': 'Const', 'type': 'Const'}), + **regular_op('batch_dot', {'op': 'batch_dot', 'type': None, 'transpose_a': True, 'transpose_b': False}), + **result('result') + }, + edges=[ + *connect_front('input', '0:batch_dot'), + *connect_front('const_input', '1:batch_dot'), + *connect_front('batch_dot', 'result'), + ] + ) + + ref_graph = build_graph( + nodes_attrs={ + **regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}), + **regular_op('const_input', {'op': 'Const', 'type': 'Const'}), + **regular_op('mat_mul', {'op': 'MatMul', 'type': 'MatMul', 'transpose_a': True, 'transpose_b': False}), + **result('result') + }, + edges=[ + *connect_front('input', '0:mat_mul'), + *connect_front('const_input', '1:mat_mul'), + *connect_front('mat_mul', 'result'), + ] + ) + graph.stage = 'front' + BatchDotReplacer().find_and_replace_pattern(graph) + flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True) + self.assertTrue(flag, resp)