Support Dilated convolution in dynamic case (#14615)
* Add Mod evaluate method * Fix mo transformation for dynamic case * Fix build
This commit is contained in:
parent
bc685ac8a0
commit
7df9031411
@ -27,6 +27,9 @@ public:
|
||||
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
};
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "ngraph/op/mod.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/runtime/reference/mod.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -21,3 +22,81 @@ shared_ptr<Node> op::v1::Mod::clone_with_new_inputs(const OutputVector& new_args
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<Mod>(new_args.at(0), new_args.at(1), this->get_autob());
|
||||
}
|
||||
|
||||
namespace mod_op {
|
||||
namespace {
|
||||
template <typename T>
|
||||
bool evaluate(const ov::Tensor& arg0,
|
||||
const ov::Tensor& arg1,
|
||||
const ov::Tensor& out,
|
||||
const op::AutoBroadcastSpec& broadcast_spec) {
|
||||
runtime::reference::mod(arg0.data<T>(),
|
||||
arg1.data<T>(),
|
||||
out.data<T>(),
|
||||
arg0.get_shape(),
|
||||
arg1.get_shape(),
|
||||
broadcast_spec);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_mod(const ov::Tensor& arg0,
|
||||
const ov::Tensor& arg1,
|
||||
const ov::Tensor& out,
|
||||
const op::AutoBroadcastSpec& broadcast_spec) {
|
||||
bool rc = true;
|
||||
switch (arg0.get_element_type()) {
|
||||
case ov::element::Type_t::i8: {
|
||||
rc = evaluate<int8_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::i16: {
|
||||
rc = evaluate<int16_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::i32: {
|
||||
rc = evaluate<int32_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::i64: {
|
||||
rc = evaluate<int64_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::u8: {
|
||||
rc = evaluate<uint8_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::u16: {
|
||||
rc = evaluate<uint16_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::u32: {
|
||||
rc = evaluate<uint32_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
case ov::element::Type_t::u64: {
|
||||
rc = evaluate<uint64_t>(arg0, arg1, out, broadcast_spec);
|
||||
} break;
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace mod_op
|
||||
|
||||
bool op::v1::Mod::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v1_Mod_evaluate);
|
||||
return mod_op::evaluate_mod(inputs[0], inputs[1], outputs[0], get_autob());
|
||||
}
|
||||
|
||||
bool op::v1::Mod::has_evaluate() const {
|
||||
OV_OP_SCOPE(v1_Mod_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::i8:
|
||||
case ngraph::element::i16:
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u8:
|
||||
case ngraph::element::u16:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -30,29 +30,37 @@ class DilatedConvolutionConverter(MiddleReplacementPattern):
|
||||
('conv', dict(kind='op', op=lambda value: value in ['Conv2D', 'DepthwiseConv2dNative', 'Conv3D'])),
|
||||
('space_to_batch', dict(kind='op', op='SpaceToBatch')),
|
||||
('batch_to_space', dict(kind='op', op='BatchToSpace')),
|
||||
('stb_pad_begin', dict(kind='op', op='Const')),
|
||||
('stb_pad_end', dict(kind='op', op='Const')),
|
||||
('bts_crop_begin', dict(kind='op', op='Const')),
|
||||
('bts_crop_end', dict(kind='op', op='Const')),
|
||||
('input', dict(kind='data')),
|
||||
('output', dict(kind='data')),
|
||||
('conv_output', dict(kind='data')),
|
||||
('stb_output', dict(kind='data')),
|
||||
('stb_bs', dict(kind='data')),
|
||||
('stb_pad_begin', dict(kind='data')),
|
||||
('stb_pad_end', dict(kind='data')),
|
||||
('stb_pad_begin_d', dict(kind='data')),
|
||||
('stb_pad_end_d', dict(kind='data')),
|
||||
('bts_bs', dict(kind='data')),
|
||||
('bts_crop_begin', dict(kind='data')),
|
||||
('bts_crop_end', dict(kind='data'))
|
||||
('bts_crop_begin_d', dict(kind='data')),
|
||||
('bts_crop_end_d', dict(kind='data'))
|
||||
],
|
||||
edges=[
|
||||
('input', 'space_to_batch', {'in': 0}),
|
||||
('stb_bs', 'space_to_batch', {'in': 1}),
|
||||
('stb_pad_begin', 'space_to_batch', {'in': 2}),
|
||||
('stb_pad_end', 'space_to_batch', {'in': 3}),
|
||||
('stb_pad_begin', 'stb_pad_begin_d', {'in': 0}),
|
||||
('stb_pad_begin_d', 'space_to_batch', {'in': 2}),
|
||||
('stb_pad_end', 'stb_pad_end_d', {'in': 0}),
|
||||
('stb_pad_end_d', 'space_to_batch', {'in': 3}),
|
||||
('space_to_batch', 'stb_output', {'out': 0}),
|
||||
('stb_output', 'conv', {'in': 0}),
|
||||
('conv', 'conv_output', {'out': 0}),
|
||||
('conv_output', 'batch_to_space', {'in': 0}),
|
||||
('bts_bs', 'batch_to_space', {'in': 1}),
|
||||
('bts_crop_begin', 'batch_to_space', {'in': 2}),
|
||||
('bts_crop_end', 'batch_to_space', {'in': 3}),
|
||||
('bts_crop_begin', 'bts_crop_begin_d', {'in': 0}),
|
||||
('bts_crop_begin_d', 'batch_to_space', {'in': 2}),
|
||||
('bts_crop_end', 'bts_crop_end_d', {'in': 0}),
|
||||
('bts_crop_end_d', 'batch_to_space', {'in': 3}),
|
||||
('batch_to_space', 'output', {'out': 0}),
|
||||
])
|
||||
|
||||
@ -63,31 +71,17 @@ class DilatedConvolutionConverter(MiddleReplacementPattern):
|
||||
|
||||
block_size = match['stb_bs']
|
||||
|
||||
input = match['input']
|
||||
output = match['output']
|
||||
stb_out = match['stb_output']
|
||||
conv_out = match['conv_output']
|
||||
|
||||
in_edge_attrs = graph.get_edge_data(input.id, stb.id)[0]
|
||||
out_edge_attrs = graph.get_edge_data(bts.id, output.id)[0]
|
||||
|
||||
graph.remove_edge(input.id, stb.id)
|
||||
graph.remove_edge(stb_out.id, conv.id)
|
||||
graph.remove_edge(conv.id, conv_out.id)
|
||||
graph.remove_edge(bts.id, output.id)
|
||||
conv.in_port(0).disconnect()
|
||||
stb.in_port(0).get_connection().set_destination(conv.in_port(0))
|
||||
bts.out_port(0).get_connection().set_source(conv.out_port(0))
|
||||
|
||||
conv.dilation[conv.spatial_dims] = block_size.value[conv.spatial_dims]
|
||||
|
||||
pad_begin = match['stb_pad_begin'].value - match['bts_crop_begin'].value
|
||||
pad_end = match['stb_pad_end'].value - match['bts_crop_end'].value
|
||||
pad_begin = match['stb_pad_begin_d'].value - match['bts_crop_begin_d'].value
|
||||
pad_end = match['stb_pad_end_d'].value - match['bts_crop_end_d'].value
|
||||
conv.pad[conv.spatial_dims] = [[pad_begin[x], pad_end[x]] for x in conv.spatial_dims]
|
||||
conv['auto_pad'] = None
|
||||
|
||||
graph.add_edges_from([
|
||||
(input.id, conv.id, {'in': 0, **in_edge_attrs}),
|
||||
(conv.id, output.id, {'out': 0, **out_edge_attrs}),
|
||||
])
|
||||
|
||||
|
||||
class DilatedConvolution1DConverter(MiddleReplacementPattern):
|
||||
"""
|
||||
|
67
tools/mo/unit_tests/mo/middle/DilatedConvolution_test.py
Normal file
67
tools/mo/unit_tests/mo/middle/DilatedConvolution_test.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from openvino.tools.mo.middle.DilatedConvolution import DilatedConvolutionConverter
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, result, connect, \
|
||||
regular_op_with_shaped_data, valued_const_with_data
|
||||
|
||||
shape = int64_array([1, 375, 500, 24])
|
||||
nodes = {**regular_op_with_shaped_data('input', shape, {'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**valued_const_with_data('stb_bs', int64_array([1, 32, 32, 1])),
|
||||
**valued_const_with_data('stb_pad_begin', int64_array([0, 32, 32, 0])),
|
||||
**valued_const_with_data('stb_pad_end', int64_array([0, 41, 44, 0])),
|
||||
**regular_op_with_shaped_data('space_to_batch', int64_array([1024, 14, 18, 24]),
|
||||
{'op': 'SpaceToBatch', 'name': 'stb'}),
|
||||
**regular_op_with_shaped_data('conv', int64_array([1024, 12, 16, 24]),
|
||||
{'op': 'Conv2D', 'name': 'conv', 'spatial_dims': int64_array([1, 2]),
|
||||
'dilation': int64_array([1, 1, 1, 1]),
|
||||
'pad': int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])}),
|
||||
**valued_const_with_data('bts_bs', int64_array([1, 32, 32, 1])),
|
||||
**valued_const_with_data('bts_crop_begin', int64_array([0, 0, 0, 0])),
|
||||
**valued_const_with_data('bts_crop_end', int64_array([0, 9, 12, 0])),
|
||||
**regular_op_with_shaped_data('batch_to_space', shape, {'op': 'BatchToSpace', 'name': 'bts'}),
|
||||
**result('result')
|
||||
}
|
||||
|
||||
edges = [*connect('input', '0:space_to_batch'),
|
||||
*connect('stb_bs', '1:space_to_batch'),
|
||||
*connect('stb_pad_begin', '2:space_to_batch'),
|
||||
*connect('stb_pad_end', '3:space_to_batch'),
|
||||
*connect('space_to_batch', '0:conv'),
|
||||
*connect('conv', '0:batch_to_space'),
|
||||
*connect('bts_bs', '1:batch_to_space'),
|
||||
*connect('bts_crop_begin', '2:batch_to_space'),
|
||||
*connect('bts_crop_end', '3:batch_to_space'),
|
||||
*connect('batch_to_space', 'result')
|
||||
]
|
||||
|
||||
ref_nodes = {**regular_op_with_shaped_data('input', shape, {'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('conv', shape,
|
||||
{'op': 'Conv2D', 'name': 'conv', 'spatial_dims': int64_array([1, 2]),
|
||||
'dilation': int64_array([1, 32, 32, 1]), 'auto_pad': None,
|
||||
'pad': int64_array([[0, 0], [32, 32], [32, 32], [0, 0]])}),
|
||||
**result('result')
|
||||
}
|
||||
ref_edges = [*connect('input', '0:conv'),
|
||||
*connect('conv', 'result')
|
||||
]
|
||||
|
||||
|
||||
class DilatedConvolutionTest(unittest.TestCase):
|
||||
def test_dilated_conv_1(self):
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
graph_ref = build_graph(ref_nodes, ref_edges)
|
||||
|
||||
graph.graph['layout'] = 'NHWC'
|
||||
graph.stage = 'middle'
|
||||
|
||||
DilatedConvolutionConverter().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
Loading…
Reference in New Issue
Block a user