added replacer for batch_dot
This commit is contained in:
parent
fe4e714c76
commit
8c0e52f7dc
@ -273,6 +273,7 @@ openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
|
|||||||
openvino/tools/mo/front/mxnet/arange_ext.py
|
openvino/tools/mo/front/mxnet/arange_ext.py
|
||||||
openvino/tools/mo/front/mxnet/arange_replacer.py
|
openvino/tools/mo/front/mxnet/arange_replacer.py
|
||||||
openvino/tools/mo/front/mxnet/batch_dot_ext.py
|
openvino/tools/mo/front/mxnet/batch_dot_ext.py
|
||||||
|
openvino/tools/mo/front/mxnet/batch_dot_replacer.py
|
||||||
openvino/tools/mo/front/mxnet/block_grad_ext.py
|
openvino/tools/mo/front/mxnet/block_grad_ext.py
|
||||||
openvino/tools/mo/front/mxnet/box_nms_ext.py
|
openvino/tools/mo/front/mxnet/box_nms_ext.py
|
||||||
openvino/tools/mo/front/mxnet/cast_ext.py
|
openvino/tools/mo/front/mxnet/cast_ext.py
|
||||||
@ -850,6 +851,7 @@ openvino/tools/mo/ops/assert_op.py
|
|||||||
openvino/tools/mo/ops/assign.py
|
openvino/tools/mo/ops/assign.py
|
||||||
openvino/tools/mo/ops/aten.py
|
openvino/tools/mo/ops/aten.py
|
||||||
openvino/tools/mo/ops/axpy.py
|
openvino/tools/mo/ops/axpy.py
|
||||||
|
openvino/tools/mo/ops/batch_dot.py
|
||||||
openvino/tools/mo/ops/BatchNormInference.py
|
openvino/tools/mo/ops/BatchNormInference.py
|
||||||
openvino/tools/mo/ops/binarization.py
|
openvino/tools/mo/ops/binarization.py
|
||||||
openvino/tools/mo/ops/BlockLSTM.py
|
openvino/tools/mo/ops/BlockLSTM.py
|
||||||
|
@ -56,8 +56,8 @@ class FullyConnectedDecomposer(FrontReplacementSubgraph):
|
|||||||
node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
|
node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
|
||||||
new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0]))
|
new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0]))
|
||||||
|
|
||||||
# input normalization for 4D Caffe
|
# input normalization for 4D Caffe and MxNet FullyConnected
|
||||||
if graph.graph['fw'] in ['caffe']:
|
if graph.graph['fw'] in ['caffe', 'mxnet']:
|
||||||
node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
|
node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
|
||||||
new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1]))
|
new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1]))
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||||
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
|
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
|
||||||
from openvino.tools.mo.graph.graph import Node
|
from openvino.tools.mo.graph.graph import Node
|
||||||
from openvino.tools.mo.ops.MatMul import MatMul
|
from openvino.tools.mo.ops.batch_dot import MXNetBatchDot
|
||||||
|
|
||||||
|
|
||||||
class BatchDotExt(FrontExtractorOp):
|
class BatchDotExt(FrontExtractorOp):
|
||||||
@ -23,7 +23,7 @@ class BatchDotExt(FrontExtractorOp):
|
|||||||
transpose_a = attrs.bool('transpose_a', False)
|
transpose_a = attrs.bool('transpose_a', False)
|
||||||
transpose_b = attrs.bool('transpose_b', False)
|
transpose_b = attrs.bool('transpose_b', False)
|
||||||
|
|
||||||
MatMul.update_node_stat(node, {
|
MXNetBatchDot.update_node_stat(node, {
|
||||||
'transpose_a': transpose_a,
|
'transpose_a': transpose_a,
|
||||||
'transpose_b': transpose_b
|
'transpose_b': transpose_b
|
||||||
})
|
})
|
||||||
|
27
tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py
Normal file
27
tools/mo/openvino/tools/mo/front/mxnet/batch_dot_replacer.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from openvino.tools.mo.front.MatMul_normalizer import FullyConnectedDecomposer
|
||||||
|
from openvino.tools.mo.front.common.replacement import FrontReplacementPattern
|
||||||
|
from openvino.tools.mo.graph.graph import Graph, rename_nodes
|
||||||
|
from openvino.tools.mo.ops.MatMul import MatMul
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDotReplacer(FrontReplacementPattern):
|
||||||
|
"""
|
||||||
|
Replaces MXNet batch_dot with MatMul. Should run after FullyConnectedDecomposer, because batch_dot does not need
|
||||||
|
in MatMul normalization.
|
||||||
|
"""
|
||||||
|
enabled = True
|
||||||
|
|
||||||
|
def run_after(self):
|
||||||
|
return [FullyConnectedDecomposer]
|
||||||
|
|
||||||
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
for node in graph.get_op_nodes(op='batch_dot'):
|
||||||
|
node_name = node.soft_get('name', node.id)
|
||||||
|
matmul_node = MatMul(graph, dict(name='MatMul', transpose_a=node.transpose_a,
|
||||||
|
transpose_b=node.transpose_b)).create_node()
|
||||||
|
node.in_port(0).get_connection().set_destination(matmul_node.in_port(0))
|
||||||
|
node.in_port(1).get_connection().set_destination(matmul_node.in_port(1))
|
||||||
|
node.out_port(0).get_connection().set_source(matmul_node.out_port(0))
|
||||||
|
rename_nodes([(matmul_node, node_name)])
|
27
tools/mo/openvino/tools/mo/ops/batch_dot.py
Normal file
27
tools/mo/openvino/tools/mo/ops/batch_dot.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from openvino.tools.mo.graph.graph import Graph
|
||||||
|
from openvino.tools.mo.ops.op import Op
|
||||||
|
|
||||||
|
|
||||||
|
class MXNetBatchDot(Op):
|
||||||
|
"""
|
||||||
|
MXNet operation which computes matrix multiplication of x and y similar to TF or ONNX MatMul operation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
transpose_a - if true then transpose the first input before multiplication
|
||||||
|
transpose_b - if true then transpose the second input before multiplication
|
||||||
|
"""
|
||||||
|
op = 'batch_dot'
|
||||||
|
|
||||||
|
def __init__(self, graph: Graph, attrs: dict):
|
||||||
|
mandatory_props = {
|
||||||
|
'type': None,
|
||||||
|
'op': self.op,
|
||||||
|
'infer': None,
|
||||||
|
'in_ports_count': 2,
|
||||||
|
'out_ports_count': 1,
|
||||||
|
'transpose_a': False,
|
||||||
|
'transpose_b': False
|
||||||
|
}
|
||||||
|
super().__init__(graph, mandatory_props, attrs)
|
@ -14,7 +14,7 @@ class DivSqrtDimOp(Op):
|
|||||||
def __init__(self, graph: Graph, attrs: dict):
|
def __init__(self, graph: Graph, attrs: dict):
|
||||||
mandatory_props = {
|
mandatory_props = {
|
||||||
'type': None,
|
'type': None,
|
||||||
'op': __class__.op,
|
'op': self.op,
|
||||||
'infer': None,
|
'infer': None,
|
||||||
'in_ports_count': 1,
|
'in_ports_count': 1,
|
||||||
'out_ports_count': 1,
|
'out_ports_count': 1,
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from openvino.tools.mo.front.mxnet.batch_dot_replacer import BatchDotReplacer
|
||||||
|
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from unit_tests.utils.graph import build_graph, regular_op, result, connect_front
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDotReplacerTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_1(self):
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}),
|
||||||
|
**regular_op('const_input', {'op': 'Const', 'type': 'Const'}),
|
||||||
|
**regular_op('batch_dot', {'op': 'batch_dot', 'type': None, 'transpose_a': True, 'transpose_b': False}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect_front('input', '0:batch_dot'),
|
||||||
|
*connect_front('const_input', '1:batch_dot'),
|
||||||
|
*connect_front('batch_dot', 'result'),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**regular_op('input', {'op': 'Parameter', 'type': 'Parameter'}),
|
||||||
|
**regular_op('const_input', {'op': 'Const', 'type': 'Const'}),
|
||||||
|
**regular_op('mat_mul', {'op': 'MatMul', 'type': 'MatMul', 'transpose_a': True, 'transpose_b': False}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect_front('input', '0:mat_mul'),
|
||||||
|
*connect_front('const_input', '1:mat_mul'),
|
||||||
|
*connect_front('mat_mul', 'result'),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
graph.stage = 'front'
|
||||||
|
BatchDotReplacer().find_and_replace_pattern(graph)
|
||||||
|
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
Loading…
Reference in New Issue
Block a user