Extend MO/nGraph for operation MaxPool-8 (#6776)
* MO update for MaxPool-8
* using attrs in op constructor
* remove pads_value attr and pad fusing
* added axis and index_element_type parameters
* updated mo maxpool-8
* added maxpool-8->maxpool-1 transformation, disabled Pad to MaxPool fusing
* added remove_values_output to pooling extractors
* moved remove_values_output attribute to pooling infer function
* fixed ir_comparator tests
* disabled pad to maxpool fusing test
* added downgrade transformation test
* downgrade transformation update
* updated ir reader and tf pooling layer tests
* updated onnx pooling layer tests and MO infer unit test
* updated ir reader extender
* uncommented layer tests code
* disabled MaxPool-8 python binding test
* comment resolving, removed PadMaxPool fusing
* removed test
* downgrade transformation fix, MO codestyle changes
* removed axis check from downgrade transformation
* mark max_pool_test as xfail
* updated downgrade transformation test
* using OPENVINO_RTTI
This commit is contained in:
parent a429044038
commit 90a140ae98
@@ -13,7 +13,6 @@ namespace pass {

 class TRANSFORMATIONS_API PadFusion;
 class TRANSFORMATIONS_API PadFusionAvgPool;
-class TRANSFORMATIONS_API PadFusionMaxPool;
 class TRANSFORMATIONS_API PadFusionConvolution;
 class TRANSFORMATIONS_API PadFusionConvolutionBackpropData;
 class TRANSFORMATIONS_API PadFusionGroupConvolution;

@@ -36,19 +35,6 @@ public:
     PadFusionAvgPool();
 };

-/**
- * @ingroup ie_transformation_common_api
- * @brief PadFusion transformation replaces following graph:
- * Pad -> MaxPool to MaxPool, under following conditions
- * - pad mode is op::PadMode::CONSTANT
- * - pad value is 0
- */
-class ngraph::pass::PadFusionMaxPool: public ngraph::pass::MatcherPass {
-public:
-    NGRAPH_RTTI_DECLARATION;
-    PadFusionMaxPool();
-};
-
 /**
  * @ingroup ie_transformation_common_api
  * @brief PadFusion transformation replaces following graph:

@@ -108,7 +94,6 @@ public:
     NGRAPH_RTTI_DECLARATION;
     PadFusion() {
         add_matcher<ngraph::pass::PadFusionAvgPool>();
-        add_matcher<ngraph::pass::PadFusionMaxPool>();
         add_matcher<ngraph::pass::PadFusionConvolution>();
         add_matcher<ngraph::pass::PadFusionConvolutionBackpropData>();
         add_matcher<ngraph::pass::PadFusionGroupConvolution>();
@@ -0,0 +1,26 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertMaxPool8ToMaxPool1;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief ConvertMaxPool8ToMaxPool1 converts v8::MaxPool into v1::MaxPool.
+ */
+class ngraph::pass::ConvertMaxPool8ToMaxPool1 : public ngraph::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ConvertMaxPool8ToMaxPool1");
+    ConvertMaxPool8ToMaxPool1();
+};
@@ -77,6 +77,7 @@
 #include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
 #include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
 #include "transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp"
+#include "transformations/op_conversions/convert_maxpool_downgrade.hpp"

 #include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/constant_folding.hpp>

@@ -164,6 +165,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::ConvertGather1ToGather7, false>();
     manager.register_pass<ngraph::pass::ConvertGather7ToGather8, false>();
     manager.register_pass<ngraph::pass::ConvertDeformableConv8To1>();
+    manager.register_pass<ngraph::pass::ConvertMaxPool8ToMaxPool1>();

     auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
     fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
@@ -132,48 +132,6 @@ pass::PadFusionAvgPool::PadFusionAvgPool() {
     this->register_matcher(m, callback);
 }

-NGRAPH_RTTI_DEFINITION(pass::PadFusionMaxPool, "PadFusionMaxPool", 0);
-
-pass::PadFusionMaxPool::PadFusionMaxPool() {
-    MATCHER_SCOPE(PadFusionMaxPool);
-    auto data_pattern = pattern::any_input();
-    auto pads_begin_pattern = pattern::wrap_type<opset5::Constant>();
-    auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
-    auto pad_value_pattern = pattern::wrap_type<opset5::Constant>();
-    auto pad_node_pattern = pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern,
-                                                             pads_end_pattern, pad_value_pattern},
-                                                            pattern::consumers_count(1));
-    auto max_pool_pattern = pattern::wrap_type<opset5::MaxPool>({pad_node_pattern});
-
-    matcher_pass_callback callback = [=](pattern::Matcher& m) {
-        auto pattern_map = m.get_pattern_value_map();
-        auto data = pattern_map[data_pattern];
-        auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
-        auto pad_value_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pad_value_pattern].get_node_shared_ptr());
-        auto pads_begin = std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
-        auto pads_end = std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_end_pattern].get_node_shared_ptr());
-        auto max_pool = std::dynamic_pointer_cast<opset5::MaxPool>(pattern_map[max_pool_pattern].get_node_shared_ptr());
-        if (!can_be_fused(pad, max_pool, pad_value_const, pads_begin, pads_end))
-            return false;
-
-        Shape new_pads_begin, new_pads_end;
-        std::tie(new_pads_begin, new_pads_end) = new_pooling_pad_values(pads_begin, pads_end, max_pool);
-        auto new_max_pool = std::make_shared<opset5::MaxPool>(data, max_pool->get_strides(),
-                                                              new_pads_begin, new_pads_end,
-                                                              max_pool->get_kernel(), max_pool->get_rounding_type(),
-                                                              op::PadType::EXPLICIT);
-        new_max_pool->set_friendly_name(max_pool->get_friendly_name());
-
-        copy_runtime_info({pad, max_pool}, new_max_pool);
-        replace_node(max_pool, new_max_pool);
-
-        return true;
-    };
-
-    auto m = std::make_shared<pattern::Matcher>(max_pool_pattern, matcher_name);
-    this->register_matcher(m, callback);
-}
-
 template <typename T>
 static std::tuple<CoordinateDiff, CoordinateDiff> new_conv_pad_values(const std::shared_ptr<opset5::Constant>& pads_begin,
                                                                       const std::shared_ptr<opset5::Constant>& pads_end,
@@ -0,0 +1,56 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/op_conversions/convert_maxpool_downgrade.hpp"
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <transformations/utils/utils.hpp>
+#include "itt.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+
+pass::ConvertMaxPool8ToMaxPool1::ConvertMaxPool8ToMaxPool1() {
+    MATCHER_SCOPE(ConvertMaxPool8ToMaxPool1);
+
+    auto maxpool_v8_pattern = pattern::wrap_type<ngraph::opset8::MaxPool>();
+
+    matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto maxpool_v8_node = std::dynamic_pointer_cast<ngraph::opset8::MaxPool>(m.get_match_root());
+
+        if (!maxpool_v8_node || maxpool_v8_node->get_output_target_inputs(1).size() != 0)
+            return false;
+
+        for (auto dilation : maxpool_v8_node->get_dilations())
+            if (dilation != 1)
+                return false;
+
+        auto maxpool_v1_node = make_shared<ngraph::opset1::MaxPool>(maxpool_v8_node->input_value(0),
+                                                                    maxpool_v8_node->get_strides(),
+                                                                    maxpool_v8_node->get_pads_begin(),
+                                                                    maxpool_v8_node->get_pads_end(),
+                                                                    maxpool_v8_node->get_kernel(),
+                                                                    maxpool_v8_node->get_rounding_type(),
+                                                                    maxpool_v8_node->get_auto_pad());
+
+        auto out_name = ngraph::op::util::create_ie_output_name(maxpool_v8_node->output(0));
+
+        maxpool_v1_node->set_friendly_name(maxpool_v8_node->get_friendly_name());
+        maxpool_v8_node->output(0).replace(maxpool_v1_node->output(0));
+        ngraph::copy_runtime_info(maxpool_v8_node, maxpool_v1_node);
+        maxpool_v8_node->clear_control_dependencies();
+
+        NGRAPH_SUPPRESS_DEPRECATED_START
+        maxpool_v1_node->output(0).get_tensor().set_name(out_name);
+        NGRAPH_SUPPRESS_DEPRECATED_END
+
+        return true;
+    };
+
+    auto m = make_shared<pattern::Matcher>(maxpool_v8_pattern, matcher_name);
+    register_matcher(m, callback);
+}
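The matcher above bails out unless the indices output (port 1) is unused and every dilation equals 1; only under those conditions is a MaxPool-8 expressible as a MaxPool-1. A standalone NumPy sketch of why the dilation condition matters (the max_pool_1d helper is hypothetical, written only for this illustration):

import numpy as np

def max_pool_1d(x, kernel, stride, dilation):
    # Effective kernel span with dilation: (kernel - 1) * dilation + 1.
    span = (kernel - 1) * dilation + 1
    n_out = (len(x) - span) // stride + 1
    return np.array([x[i * stride: i * stride + span: dilation].max()
                     for i in range(n_out)])

x = np.array([1, 5, 2, 8, 3, 9, 4])
# dilation == 1: identical to a plain (v1-style) max pool
assert np.array_equal(max_pool_1d(x, kernel=2, stride=1, dilation=1),
                      np.array([5, 5, 8, 8, 9, 9]))
# dilation == 2: samples every other element, which v1 cannot express
print(max_pool_1d(x, kernel=2, stride=1, dilation=2))  # [2 8 3 9 4]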
@@ -0,0 +1,44 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/op_conversions/convert_maxpool_downgrade.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST_F(TransformationTestsF, ConvertMaxPool8ToMaxPool1) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        ngraph::Strides strides{1}, dilations{1};
+        ngraph::Shape pads_begin{0}, pads_end{0}, kernel{1};
+        auto maxpool_8 = std::make_shared<ngraph::opset8::MaxPool>(data, strides, dilations, pads_begin, pads_end,
+                                                                   kernel);
+        auto result = std::make_shared<ngraph::opset1::Result>(maxpool_8->output(0));
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{result}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertMaxPool8ToMaxPool1>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        ngraph::Strides strides{1};
+        ngraph::Shape pads_begin{0}, pads_end{0}, kernel{1};
+        auto maxpool_1 = std::make_shared<ngraph::opset1::MaxPool>(data, strides, pads_begin, pads_end,
+                                                                   kernel);
+        auto result = std::make_shared<ngraph::opset1::Result>(maxpool_1->output(0));
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{result}, ngraph::ParameterVector{data});
+    }
+}
@@ -90,27 +90,6 @@ TEST_F(TransformationTestsF, PadFusionAvgPoolDontExcludePad) {
     }
 }

-TEST_F(TransformationTestsF, PadFusionMaxPool) {
-    Shape data_shape{1, 3, 14, 14};
-    {
-        auto data = std::make_shared<opset5::Parameter>(element::i32, data_shape);
-        auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
-        auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
-        auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
-        auto max_pool = std::make_shared<opset5::MaxPool>(pad, Strides{1, 1},
-                                                          Shape{0, 0}, Shape{1, 1}, Shape{4, 4});
-        function = std::make_shared<Function>(NodeVector{max_pool}, ParameterVector{data});
-        manager.register_pass<pass::PadFusion>();
-    }
-    {
-        auto data = std::make_shared<opset5::Parameter>(element::i32, data_shape);
-        auto max_pool = std::make_shared<opset5::MaxPool>(data, Strides{1, 1},
-                                                          Shape{1, 1}, Shape{3, 3}, Shape{4, 4},
-                                                          op::RoundingType::FLOOR, op::PadType::EXPLICIT);
-        function_ref = std::make_shared<Function>(NodeVector{max_pool}, ParameterVector{data});
-    }
-}
-
 TEST_F(TransformationTestsF, PadFusionConvolution) {
     Shape data_shape{1, 3, 14, 14};
     {
@@ -3,6 +3,7 @@

 from mo.back.replacement import BackReplacementPattern
 from mo.graph.graph import Graph
+from mo.ops.result import Result


 class MaxPool(BackReplacementPattern):

@@ -25,3 +26,14 @@ class MaxPool(BackReplacementPattern):
         del node['pool_method']
         if 'exclude_pad' in node:
             del node['exclude_pad']
+
+        # adding missed outputs for MaxPool node
+        if node.out_port(0).disconnected():
+            output = Result(node.graph, {'name': node.name + '/Result_port_0/',
+                                         'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
+            node.out_port(0).get_connection().set_destination(output.in_port(0))
+
+        if node.out_port(1).disconnected():
+            output = Result(node.graph, {'name': node.name + '/Result_port_1/',
+                                         'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
+            node.out_port(1).get_connection().set_destination(output.in_port(0))
@@ -5,6 +5,7 @@ import logging as log

 import numpy as np

+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.extractor import FrontExtractorOp
 from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_autopad
 from mo.ops.pooling import Pooling

@@ -92,9 +93,8 @@ def common_onnx_pool_extractor(node):
     strides = onnx_attr(node, 'strides', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
     final_strides = np.array([1, 1, *[x for x in strides]], dtype=np.int64) if strides is not None else None

-    dilations = onnx_attr(node, 'dilations', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
-    assert dilations is None or np.all(dilations == 1),\
-        'Node {} has "dilations" attribute with values not equal to 1s which is not supported'.format(node.id)
+    dilation = onnx_attr(node, 'dilations', 'ints', default=None, dst_type=lambda x: int64_array(x))
+    final_dilation = int64_array([1, 1, *[x for x in dilation]]) if dilation is not None else None

     # exclude_pad = True only when count_include_pad == 0
     exclude_pad = onnx_attr(node, 'count_include_pad', 'i', default=0) == 0

@@ -127,6 +127,7 @@ def common_onnx_pool_extractor(node):
         'global_pool': global_pooling,
         'output_spatial_shape': None,
         'rounding_type': rt,
+        'dilation': final_dilation,

         'spatial_dims': None,
         'channel_dims': np.array([1], dtype=np.int64),
@@ -216,10 +216,13 @@ def convert_deconv_tf_padding_to_str(padding):


 # TODO eliminate this dependency and pass necessary function as an argument
-def tf_window_op_pad_infer(input, window, stride, auto_pad, is_deconv=False):
+def tf_window_op_pad_infer(input, window, stride, auto_pad, is_deconv=False, dilation=None):
     if input is None or window is None or stride is None or auto_pad is None:
         return None, None
+
+    if dilation is None:
+        dilation = np.ones(len(input), dtype=np.int64)
+
     normalized_stride = stride
     if is_deconv:
         normalized_stride = 1 / stride

@@ -237,7 +240,7 @@ def tf_window_op_pad_infer(input, window, stride, auto_pad, is_deconv=False):
         high_pad = full_pad - low_pad
         pad = shape_array([low_pad, high_pad]).transpose()
     elif auto_pad == 'valid':
-        output = np.int64(np.ceil((input - window + 1) / normalized_stride))
+        output = np.int64(np.ceil((input - ((window - 1) * dilation + 1) + 1) / normalized_stride))
         pad = np.zeros((len(output), 2), dtype=np.int64)
     else:
         log.error("Unsupported padding scheme: {}".format(auto_pad))
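The 'valid' branch now subtracts the dilated effective window, (window - 1) * dilation + 1, instead of the raw window size. A quick standalone check of the new formula, using the same numbers as the MO unit test added later in this commit (spatial size 256, window 2, stride 2, dilation 2):

import numpy as np

input_, window, stride, dilation = (np.array([256]), np.array([2]),
                                    np.array([2]), np.array([2]))
effective_window = (window - 1) * dilation + 1          # [3]
output = np.int64(np.ceil((input_ - effective_window + 1) / stride))
print(output)  # [127], matching the expected [1, 3, 127, 127] shape in the test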
@@ -22,6 +22,9 @@ def pad_op_transform(graph: Graph, match: dict):
         log.info('The pad node "{}" with pad mode "{}" cannot be fused.'.format(pad_op.soft_get('name'), pad_op.mode))
         return

+    if op.type == 'Pooling' and op.pool_method == 'max':
+        return
+
     if pad_op.mode == 'constant':
         fill_value = pad_op.in_port(3).data.get_value()
         if fill_value is None or fill_value != 0.0:
@@ -3,10 +3,11 @@

 import numpy as np

-from mo.front.common.partial_infer.utils import tf_window_op_pad_infer, int64_array, float_array, shape_array, \
+from mo.front.common.partial_infer.utils import tf_window_op_pad_infer, int64_array, shape_array, \
     dynamic_dimension_value, dynamic_dimension
 from mo.front.onnx.extractors.utils import get_backend_pad
 from mo.graph.graph import Node, Graph
+from mo.middle.passes.convert_data_type import np_data_type_to_destination_type
 from mo.ops.op import Op, PermuteAttrs
 from mo.utils.error import Error
 from mo.front.extractor import bool_to_str

@@ -46,6 +47,12 @@ class PoolingV2(Op):
         Pooling.pool_infer(node)


+poolings_map = {
+    'max': {'version': 'opset8', 'out_ports_count': 2},
+    'avg': {'version': 'opset1', 'out_ports_count': 1}
+}
+
+
 class Pooling(Op):
     op = 'Pooling'

@@ -53,10 +60,10 @@ class Pooling(Op):
         super().__init__(graph, {
             'type': self.op,
             'op': self.op,
-            'version': 'opset1',
+            'version': poolings_map[attrs.get('pool_method')]['version'],
             'infer': self.infer,
             'in_ports_count': 1,
-            'out_ports_count': 1,
+            'out_ports_count': 1 if attrs.get('version') == 'opset1' else poolings_map[attrs.get('pool_method')]['out_ports_count']
         }, attrs)

     def backend_attrs(self):
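A minimal sketch of how these constructor defaults resolve. MO's Op.__init__ merges the defaults dict with the caller's attrs (caller wins), which is mimicked here with {**defaults, **attrs}; the rest of the Op machinery is omitted:

poolings_map = {
    'max': {'version': 'opset8', 'out_ports_count': 2},
    'avg': {'version': 'opset1', 'out_ports_count': 1}
}

def resolve(attrs):
    defaults = {
        'version': poolings_map[attrs.get('pool_method')]['version'],
        'out_ports_count': 1 if attrs.get('version') == 'opset1'
                           else poolings_map[attrs.get('pool_method')]['out_ports_count'],
    }
    return {**defaults, **attrs}

print(resolve({'pool_method': 'max'}))
# -> version 'opset8', out_ports_count 2 (MaxPool-8 carries the indices port)
print(resolve({'pool_method': 'max', 'version': 'opset1'}))
# -> version 'opset1', out_ports_count 1 (forced single-output v1 pooling)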
@@ -71,6 +78,11 @@ class Pooling(Op):

             'rounding_type',
             ('auto_pad', lambda node: node.auto_pad if node.has_valid('auto_pad') else 'explicit'),
+
+            ('dilations', lambda node: ','.join(map(str, node['dilation'][node.spatial_dims]))),
+            'axis',
+
+            ('index_element_type', lambda node: np_data_type_to_destination_type(node.index_element_type))
         ]

     @staticmethod
@@ -105,14 +117,26 @@ class Pooling(Op):
             node['window'] = np.zeros(len(input_shape), dtype=np.int64)
             node.window[node.spatial_dims] = input_spatial_shape

+        if not node.has_valid('dilation'):
+            node['dilation'] = np.ones(len(input_shape))
+
+        if not node.has_valid('axis'):
+            node['axis'] = 0
+
+        if not node.has_valid('index_element_type'):
+            node['index_element_type'] = np.int64
+
         window_spatial_shape = node.window[node.spatial_dims]
         stride_spatial = node.stride[node.spatial_dims]
+        dilation_spatial = node.dilation[node.spatial_dims]
         assert any(stride_spatial), 'Stride can not be zero in node {}'.format(node.id)

         if node.has_valid('auto_pad') and node.auto_pad != 'explicit':
-            node.pad_spatial_shape, node.output_spatial_shape = tf_window_op_pad_infer(input_spatial_shape,
-                                                                                       window_spatial_shape,
-                                                                                       stride_spatial, node.auto_pad)
+            node.pad_spatial_shape, node.output_spatial_shape = tf_window_op_pad_infer(input=input_spatial_shape,
+                                                                                       window=window_spatial_shape,
+                                                                                       stride=stride_spatial,
+                                                                                       auto_pad=node.auto_pad,
+                                                                                       dilation=dilation_spatial)
             pad = np.zeros((len(input_shape), 2), dtype=np.int64)
             pad[node.spatial_dims] = node.pad_spatial_shape
             node.pad = pad
@@ -124,7 +148,8 @@ class Pooling(Op):
             if node.soft_get('pooling_convention') == 'full' or node.soft_get('rounding_type') == 'ceil':
                 rounding = np.ceil

-            padded_spatial_shape = input_spatial_shape + pad_spatial_shape - window_spatial_shape
+            padded_spatial_shape = input_spatial_shape + pad_spatial_shape - ((window_spatial_shape - 1) *
+                                                                              dilation_spatial + 1)
             if np.any(padded_spatial_shape < 0):
                 raise Error("Data after padding has dimension less than window size. " +
                             "Possible reason of error is incorrectly specified model input shape(s).")
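As in the 'valid' padding case, the negative-size guard now compares against the dilated span of the kernel. A standalone illustration of when it trips (the numbers are hypothetical, chosen just for the example):

import numpy as np

# window 3 with dilation 3 spans (3 - 1) * 3 + 1 = 7 elements, so a padded
# spatial extent of 5 is too small and MO raises the Error shown above
input_spatial, pad, window, dilation = (np.array([5]), np.array([0]),
                                        np.array([3]), np.array([3]))
padded = input_spatial + pad - ((window - 1) * dilation + 1)
print(padded)              # [-2]
print(np.any(padded < 0))  # True -> the model shape is rejected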
@@ -147,8 +172,15 @@ class Pooling(Op):
         output_shape[node.spatial_dims] = node.output_spatial_shape
         node.out_port(0).data.set_shape(output_shape)

+        if len(node.out_ports()) == 2 and not node.out_port(1).disconnected():
+            node.out_port(1).data.set_shape(output_shape)
+
+        if node.has_and_set('pool_method') and node['pool_method'] == 'max':
+            node['remove_values_output'] = True
+
         # Add permute_attrs
         PermuteAttrs.create_permute_attrs(node, attrs=[('pad', 'input:0'),
                                                        ('stride', 'input:0'),
                                                        ('window', 'input:0'),
-                                                       ('spatial_dims', 'input:0')])
+                                                       ('spatial_dims', 'input:0'),
+                                                       ('dilation', 'input:0')])
@@ -3,6 +3,7 @@

 from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Node
+from mo.middle.passes.convert_data_type import destination_type_to_np_data_type
 from mo.utils.ir_reader.extender import Extender

@@ -27,16 +28,23 @@ class MaxPool_extender(Extender):


 def common_pool_extender(op: Node):
-    for attr in ['strides', 'pads_begin', 'pads_end', 'kernel']:
+    for attr in ['strides', 'pads_begin', 'pads_end', 'kernel', 'dilations']:
         Extender.attr_to_list(op, attr)
     op['stride'] = int64_array([1, 1] + op.strides)
     op['window'] = int64_array([1, 1] + op.kernel)
     op['kernel_spatial'] = op.kernel
     op['output_spatial_shape'] = None

+    if op.has_valid('dilations'):
+        op['dilation'] = int64_array([1, 1] + op.dilations)
+    if op.has_valid('index_element_type'):
+        op['index_element_type'] = destination_type_to_np_data_type(op.index_element_type)
+
     op['batch_dims'] = int64_array([0]),
     op['channel_dims'] = int64_array([1]),

+    op['pool_method'] = 'max' if op.type is 'MaxPool' else 'avg'
+
     dim = len(op.pads_begin)

     assert dim in (1, 2, 3), '{}D {} not supported! Node name: {}'.format(dim, op.soft_get('type'), op.soft_get('name', op.id))
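For instance, for a 2D MaxPool-8 IR node whose dilations attribute parses to [1, 2], the extender above derives the following (int64_array is MO shorthand for np.array(..., dtype=np.int64), reproduced here so the snippet runs standalone):

import numpy as np

def int64_array(x):  # copy of MO's helper, for a self-contained run
    return np.array(x, dtype=np.int64)

dilations = [1, 2]                           # parsed from the IR attribute
dilation = int64_array([1, 1] + dilations)   # batch and channel axes get 1
print(dilation)  # [1 1 1 2]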
@@ -167,3 +167,30 @@ class TestPoolingPartialInfer(unittest.TestCase):

         with self.assertRaises(Error):
             Pooling.infer(pool_node)
+
+    def test_pooling_infer_with_dilations(self):
+        graph = build_graph(nodes_attributes,
+                            [('node_1', 'pool'),
+                             ('pool', 'node_2'),
+                             ('node_2', 'op_output')
+                             ],
+                            {'node_2': {'shape': None},
+                             'node_1': {'shape': np.array([1, 3, 256, 256])},
+                             'pool': {'window': np.array([1, 1, 2, 2]), 'stride': np.array([1, 1, 2, 2]),
+                                      'pad': np.array([[0, 0], [0, 0], [0, 0], [1, 1]]),
+                                      'pad_spatial_shape': np.array([[0, 0], [1, 1]]),
+                                      'pool_method': 'max', 'exclude_pad': False, 'global_pool': False,
+                                      'output_spatial_shape': None, 'output_shape': None,
+                                      'kernel_spatial': np.array([2, 2]), 'spatial_dims': np.array([2, 3]),
+                                      'channel_dims': np.array([1]), 'batch_dims': np.array([0]),
+                                      'pooling_convention': 'full', 'dilation': np.array([1, 1, 2, 2]),
+                                      'auto_pad': 'valid'}
+                             })
+
+        pool_node = Node(graph, 'pool')
+
+        Pooling.infer(pool_node)
+        exp_shape = np.array([1, 3, 127, 127])
+        res_shape = graph.node['node_2']['shape']
+        for i in range(0, len(exp_shape)):
+            self.assertEqual(exp_shape[i], res_shape[i])
@@ -144,3 +144,4 @@ xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike")
 xfail_issue_63137 = xfail_test(reason="Unsupported operations: OptionalHasElement, OptionalGetElement")
 xfail_issue_63138 = xfail_test(reason="Missing ONNX Shape-15 support")
 xfail_issue_63643 = xfail_test(reason="RuntimeError: Unsupported operation of type: Convolution name")
+xfail_issue_54663 = xfail_test(reason="Disabled until MaxPool-8 is supported on CPU")
@@ -11,6 +11,7 @@ from openvino.impl.op import Constant, Parameter
 from tests.runtime import get_runtime

 from tests import xfail_issue_67415
+from tests import xfail_issue_54663


 def binary_op(op_str, a, b):

@@ -543,7 +544,7 @@ def test_select():
     expected = np.array([[5, 8]])
     assert np.allclose(result, expected)

+@xfail_issue_54663
 def test_max_pool():
     # test 1d
     element_type = Type.f32
@@ -149,3 +149,4 @@ xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike")
 xfail_issue_63137 = xfail_test(reason="Unsupported operations: OptionalHasElement, OptionalGetElement")
 xfail_issue_63138 = xfail_test(reason="Missing ONNX Shape-15 support")
 xfail_issue_63643 = xfail_test(reason="RuntimeError: Unsupported operation of type: Convolution name")
+xfail_issue_54663 = xfail_test(reason="Disabled until MaxPool-8 is supported on CPU")
@@ -10,6 +10,8 @@ from ngraph.impl import AxisSet, Function, Shape, Type
 from ngraph.impl.op import Constant, Parameter
 from tests_compatibility.runtime import get_runtime

+from tests_compatibility import xfail_issue_54663
+

 def binary_op(op_str, a, b):

@@ -541,6 +543,7 @@ def test_select():
     assert np.allclose(result, expected)


+@xfail_issue_54663
 def test_max_pool():
     # test 1d
     element_type = Type.f32
@@ -132,6 +132,7 @@ class TestPooling(OnnxRuntimeLayerTest):
                 'rounding_type': 'ceil' if auto_pad != 'NOTSET' or ceil else 'floor',
                 'auto_pad': None},
             'node_data': {'shape': out_shape, 'kind': 'data'},
+            'node_indicies_data': {'shape': out_shape, 'kind': 'data'},
             'input_const_data': {'kind': 'data', 'value': constant.flatten()},
             'const': {'kind': 'op', 'type': 'Const'},
             'const_data': {'shape': out_shape, 'kind': 'data'},

@@ -141,21 +142,24 @@ class TestPooling(OnnxRuntimeLayerTest):
         }
         if op == 'AveragePool':
             nodes_attributes['node']['type'] = 'AvgPool'
-            nodes_attributes['node']['exclude-pad'] = 'true' if count_include_pad == 0 else 'false'
+            nodes_attributes['node']['exclude-pad'] = True if count_include_pad == 0 else False
         else:
             nodes_attributes['node']['type'] = 'MaxPool'

+        edges = [('input', 'input_data'),
+                 ('input_data', 'node'),
+                 ('node', 'node_data', {'out': 0}),
+                 ('input_const_data', 'const'),
+                 ('const', 'const_data'),
+                 ('node_data', 'concat'),
+                 ('const_data', 'concat'),
+                 ('concat', 'concat_data'),
+                 ('concat_data', 'result')]
+        if op == "MaxPool":
+            edges.append(('node', 'node_indicies_data', {'out': 1}))
         ref_net = build_graph(nodes_attributes,
-                              [('input', 'input_data'),
-                               ('input_data', 'node'),
-                               ('node', 'node_data'),
-                               ('input_const_data', 'const'),
-                               ('const', 'const_data'),
-                               ('node_data', 'concat'),
-                               ('const_data', 'concat'),
-                               ('concat', 'concat_data'),
-                               ('concat_data', 'result')
-                               ])
+                              edges,
+                              nodes_with_edges_only=True)

     return onnx_net, ref_net
@@ -73,7 +73,8 @@ class TestPooling(CommonTFLayerTest):
             'pooling': {'kernel': kernel_size, 'pads_begin': pads_begin, 'pads_end': pads_end,
                         'strides': strides, 'kind': 'op', 'type': None},
             'pooling_data': {'shape': out_shape, 'kind': 'data'},
-            'result': {'kind': 'op', 'type': 'Result'}
+            'result': {'kind': 'op', 'type': 'Result'},
+            'pooling_indicies_data': {'kind': 'data', 'shape': out_shape}
         }

         if method == 'avg':

@@ -81,12 +82,17 @@ class TestPooling(CommonTFLayerTest):
         elif method == 'max':
             nodes_attributes['pooling']['type'] = 'MaxPool'

+        edges = [('input', 'input_data'),
+                 ('input_data', 'pooling'),
+                 ('pooling', 'pooling_data', {'out': 0}),
+                 ('pooling_data', 'result')]
+
+        if method == 'max':
+            edges.append(('pooling', 'pooling_indicies_data', {'out': 1}))
+
         ref_net = build_graph(nodes_attributes,
-                              [('input', 'input_data'),
-                               ('input_data', 'pooling'),
-                               ('pooling', 'pooling_data'),
-                               ('pooling_data', 'result')
-                               ])
+                              edges=edges,
+                              nodes_with_edges_only=True)

     return tf_net, ref_net