added replacer for batch_dot
This commit is contained in:
parent
fe4e714c76
commit
8c0e52f7dc
@ -273,6 +273,7 @@ openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
|
||||
openvino/tools/mo/front/mxnet/arange_ext.py
|
||||
openvino/tools/mo/front/mxnet/arange_replacer.py
|
||||
openvino/tools/mo/front/mxnet/batch_dot_ext.py
|
||||
openvino/tools/mo/front/mxnet/batch_dot_replacer.py
|
||||
openvino/tools/mo/front/mxnet/block_grad_ext.py
|
||||
openvino/tools/mo/front/mxnet/box_nms_ext.py
|
||||
openvino/tools/mo/front/mxnet/cast_ext.py
|
||||
@ -850,6 +851,7 @@ openvino/tools/mo/ops/assert_op.py
|
||||
openvino/tools/mo/ops/assign.py
|
||||
openvino/tools/mo/ops/aten.py
|
||||
openvino/tools/mo/ops/axpy.py
|
||||
openvino/tools/mo/ops/batch_dot.py
|
||||
openvino/tools/mo/ops/BatchNormInference.py
|
||||
openvino/tools/mo/ops/binarization.py
|
||||
openvino/tools/mo/ops/BlockLSTM.py
|
||||
|
@ -56,8 +56,8 @@ class FullyConnectedDecomposer(FrontReplacementSubgraph):
|
||||
node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
|
||||
new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0]))
|
||||
|
||||
# input normalization for 4D Caffe
|
||||
if graph.graph['fw'] in ['caffe']:
|
||||
# input normalization for 4D Caffe and MxNet FullyConnected
|
||||
if graph.graph['fw'] in ['caffe', 'mxnet']:
|
||||
node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
|
||||
new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1]))
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.MatMul import MatMul
|
||||
from openvino.tools.mo.ops.batch_dot import MXNetBatchDot
|
||||
|
||||
|
||||
class BatchDotExt(FrontExtractorOp):
|
||||
@ -23,7 +23,7 @@ class BatchDotExt(FrontExtractorOp):
|
||||
transpose_a = attrs.bool('transpose_a', False)
|
||||
transpose_b = attrs.bool('transpose_b', False)
|
||||
|
||||
MatMul.update_node_stat(node, {
|
||||
MXNetBatchDot.update_node_stat(node, {
|
||||
'transpose_a': transpose_a,
|
||||
'transpose_b': transpose_b
|
||||
})
|
||||
|
27
tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py
Normal file
27
tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.tools.mo.front.MatMul_normalizer import FullyConnectedDecomposer
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementPattern
|
||||
from openvino.tools.mo.graph.graph import Graph, rename_nodes
|
||||
from openvino.tools.mo.ops.MatMul import MatMul
|
||||
|
||||
|
||||
class BatchDotReplacer(FrontReplacementPattern):
|
||||
"""
|
||||
Replaces MXNet batch_dot with MatMul. Should run after FullyConnectedDecomposer, because batch_dot does not need
|
||||
in MatMul normalization.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
return [FullyConnectedDecomposer]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(op='batch_dot'):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
matmul_node = MatMul(graph, dict(name='MatMul', transpose_a=node.transpose_a,
|
||||
transpose_b=node.transpose_b)).create_node()
|
||||
node.in_port(0).get_connection().set_destination(matmul_node.in_port(0))
|
||||
node.in_port(1).get_connection().set_destination(matmul_node.in_port(1))
|
||||
node.out_port(0).get_connection().set_source(matmul_node.out_port(0))
|
||||
rename_nodes([(matmul_node, node_name)])
|
27
tools/mo/openvino/tools/mo/ops/batch_dot.py
Normal file
27
tools/mo/openvino/tools/mo/ops/batch_dot.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
|
||||
|
||||
class MXNetBatchDot(Op):
|
||||
"""
|
||||
MXNet operation which computes matrix multiplication of x and y similar to TF or ONNX MatMul operation.
|
||||
|
||||
Attributes:
|
||||
transpose_a - if true then transpose the first input before multiplication
|
||||
transpose_b - if true then transpose the second input before multiplication
|
||||
"""
|
||||
op = 'batch_dot'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'infer': None,
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
'transpose_a': False,
|
||||
'transpose_b': False
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
@ -14,7 +14,7 @@ class DivSqrtDimOp(Op):
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': None,
|
||||
'op': __class__.op,
|
||||
'op': self.op,
|
||||
'infer': None,
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
|
@ -0,0 +1,44 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from openvino.tools.mo.front.mxnet.batch_dot_replacer import BatchDotReplacer
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op, result, connect_front
|
||||
|
||||
|
||||
class BatchDotReplacerTest(unittest.TestCase):
|
||||
|
||||
def test_1(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}),
|
||||
**regular_op('const_input', {'op': 'Const', 'type': 'Const'}),
|
||||
**regular_op('batch_dot', {'op': 'batch_dot', 'type': None, 'transpose_a': True, 'transpose_b': False}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect_front('input', '0:batch_dot'),
|
||||
*connect_front('const_input', '1:batch_dot'),
|
||||
*connect_front('batch_dot', 'result'),
|
||||
]
|
||||
)
|
||||
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}),
|
||||
**regular_op('const_input', {'op': 'Const', 'type': 'Const'}),
|
||||
**regular_op('mat_mul', {'op': 'MatMul', 'type': 'MatMul', 'transpose_a': True, 'transpose_b': False}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect_front('input', '0:mat_mul'),
|
||||
*connect_front('const_input', '1:mat_mul'),
|
||||
*connect_front('mat_mul', 'result'),
|
||||
]
|
||||
)
|
||||
graph.stage = 'front'
|
||||
BatchDotReplacer().find_and_replace_pattern(graph)
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
Loading…
Reference in New Issue
Block a user