added replacer for batch_dot

This commit is contained in:
yekruglov 2021-12-20 13:22:09 +03:00
parent fe4e714c76
commit 8c0e52f7dc
7 changed files with 105 additions and 5 deletions

View File

@ -273,6 +273,7 @@ openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
openvino/tools/mo/front/mxnet/arange_ext.py
openvino/tools/mo/front/mxnet/arange_replacer.py
openvino/tools/mo/front/mxnet/batch_dot_ext.py
openvino/tools/mo/front/mxnet/batch_dot_replacer.py
openvino/tools/mo/front/mxnet/block_grad_ext.py
openvino/tools/mo/front/mxnet/box_nms_ext.py
openvino/tools/mo/front/mxnet/cast_ext.py
@ -850,6 +851,7 @@ openvino/tools/mo/ops/assert_op.py
openvino/tools/mo/ops/assign.py
openvino/tools/mo/ops/aten.py
openvino/tools/mo/ops/axpy.py
openvino/tools/mo/ops/batch_dot.py
openvino/tools/mo/ops/BatchNormInference.py
openvino/tools/mo/ops/binarization.py
openvino/tools/mo/ops/BlockLSTM.py

View File

@ -56,8 +56,8 @@ class FullyConnectedDecomposer(FrontReplacementSubgraph):
node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0]))
# input normalization for 4D Caffe
if graph.graph['fw'] in ['caffe']:
# input normalization for 4D Caffe and MxNet FullyConnected
if graph.graph['fw'] in ['caffe', 'mxnet']:
node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1]))

View File

@ -3,7 +3,7 @@
from openvino.tools.mo.front.extractor import FrontExtractorOp
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.MatMul import MatMul
from openvino.tools.mo.ops.batch_dot import MXNetBatchDot
class BatchDotExt(FrontExtractorOp):
@ -23,7 +23,7 @@ class BatchDotExt(FrontExtractorOp):
transpose_a = attrs.bool('transpose_a', False)
transpose_b = attrs.bool('transpose_b', False)
MatMul.update_node_stat(node, {
MXNetBatchDot.update_node_stat(node, {
'transpose_a': transpose_a,
'transpose_b': transpose_b
})

View File

@ -0,0 +1,27 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.MatMul_normalizer import FullyConnectedDecomposer
from openvino.tools.mo.front.common.replacement import FrontReplacementPattern
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.ops.MatMul import MatMul
class BatchDotReplacer(FrontReplacementPattern):
"""
Replaces MXNet batch_dot with MatMul. Should run after FullyConnectedDecomposer, because batch_dot does not need
in MatMul normalization.
"""
enabled = True
def run_after(self):
return [FullyConnectedDecomposer]
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='batch_dot'):
node_name = node.soft_get('name', node.id)
matmul_node = MatMul(graph, dict(name='MatMul', transpose_a=node.transpose_a,
transpose_b=node.transpose_b)).create_node()
node.in_port(0).get_connection().set_destination(matmul_node.in_port(0))
node.in_port(1).get_connection().set_destination(matmul_node.in_port(1))
node.out_port(0).get_connection().set_source(matmul_node.out_port(0))
rename_nodes([(matmul_node, node_name)])

View File

@ -0,0 +1,27 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.ops.op import Op
class MXNetBatchDot(Op):
"""
MXNet operation which computes matrix multiplication of x and y similar to TF or ONNX MatMul operation.
Attributes:
transpose_a - if true then transpose the first input before multiplication
transpose_b - if true then transpose the second input before multiplication
"""
op = 'batch_dot'
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': None,
'op': self.op,
'infer': None,
'in_ports_count': 2,
'out_ports_count': 1,
'transpose_a': False,
'transpose_b': False
}
super().__init__(graph, mandatory_props, attrs)

View File

@ -14,7 +14,7 @@ class DivSqrtDimOp(Op):
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': None,
'op': __class__.op,
'op': self.op,
'infer': None,
'in_ports_count': 1,
'out_ports_count': 1,

View File

@ -0,0 +1,44 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from openvino.tools.mo.front.mxnet.batch_dot_replacer import BatchDotReplacer
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op, result, connect_front
class BatchDotReplacerTest(unittest.TestCase):
def test_1(self):
graph = build_graph(
nodes_attrs={
**regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}),
**regular_op('const_input', {'op': 'Const', 'type': 'Const'}),
**regular_op('batch_dot', {'op': 'batch_dot', 'type': None, 'transpose_a': True, 'transpose_b': False}),
**result('result')
},
edges=[
*connect_front('input', '0:batch_dot'),
*connect_front('const_input', '1:batch_dot'),
*connect_front('batch_dot', 'result'),
]
)
ref_graph = build_graph(
nodes_attrs={
**regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}),
**regular_op('const_input', {'op': 'Const', 'type': 'Const'}),
**regular_op('mat_mul', {'op': 'MatMul', 'type': 'MatMul', 'transpose_a': True, 'transpose_b': False}),
**result('result')
},
edges=[
*connect_front('input', '0:mat_mul'),
*connect_front('const_input', '1:mat_mul'),
*connect_front('mat_mul', 'result'),
]
)
graph.stage = 'front'
BatchDotReplacer().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)