Add ceil_mode for Max and Avg pooling (#2965)

This commit is contained in:
Bartosz Sledz
2020-11-16 15:16:24 +01:00
committed by GitHub
parent 5d0bfbb47f
commit 749d70bb63
5 changed files with 18 additions and 5 deletions

View File

@@ -53,6 +53,13 @@ namespace ngraph
/// (height, width, depth).
Strides get_dilations(const Node& node, const std::size_t kernel_rank = 0UL);
/// \brief Gets the 'ceil_mode' (rounding type) attribute value.
///
/// \param[in] node The ONNX node we query for attribute.
///
/// \return The nGraph RoundingType object representing 'ceil_mode' attribute value.
ngraph::op::RoundingType get_rounding_type(const Node& node);
/// \brief Get padding values for the operation described by an ONNX node.
/// \details Values are taken from the `pads` attribute.
///

View File

@@ -73,6 +73,7 @@ namespace ngraph
Shape m_padding_below;
Shape m_padding_above;
ngraph::op::PadType m_auto_pad;
ngraph::op::RoundingType m_rounding_type;
};
///

View File

@@ -99,6 +99,12 @@ namespace ngraph
return detail::get_attribute_value(node, "dilations", kernel_rank);
}
ngraph::op::RoundingType get_rounding_type(const Node& node)
{
return static_cast<ngraph::op::RoundingType>(
node.get_attribute_value<std::int64_t>("ceil_mode", 0));
}
ngraph::op::PadType get_auto_pad(const Node& node)
{
// Default value means use explicitly provided padding values.

View File

@@ -34,6 +34,7 @@ namespace ngraph
, m_strides{convpool::get_strides(node)}
, m_dilations{convpool::get_dilations(node)}
, m_auto_pad{convpool::get_auto_pad(node)}
, m_rounding_type{convpool::get_rounding_type(node)}
{
const auto paddings = convpool::get_pads(node);
const CoordinateDiff& padding_above{paddings.second};
@@ -52,7 +53,7 @@ namespace ngraph
m_padding_above,
m_kernel_shape,
!count_include_pad,
ngraph::op::RoundingType::FLOOR,
m_rounding_type,
m_auto_pad)};
}
@@ -63,7 +64,7 @@ namespace ngraph
m_padding_below,
m_padding_above,
m_kernel_shape,
ngraph::op::RoundingType::FLOOR,
m_rounding_type,
m_auto_pad)};
}

View File

@@ -194,9 +194,7 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu",
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
(xfail_issue_33616,
"OnnxBackendNodeModelTest.test_maxpool_2d_ceil_cpu",
"OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu",
"OnnxBackendNodeModelTest.test_averagepool_2d_ceil_cpu"),
"OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"),
(xfail_issue_38086,
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_expanded_cpu",
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_expanded_cpu",