Add ceil_mode for Max and Avg pooling (#2965)
This commit is contained in:
@@ -53,6 +53,13 @@ namespace ngraph
|
||||
/// (height, width, depth).
|
||||
Strides get_dilations(const Node& node, const std::size_t kernel_rank = 0UL);
|
||||
|
||||
/// \brief Gets the 'ceil_mode' (rounding type) attribute value.
|
||||
///
|
||||
/// \param[in] node The ONNX node we query for attribute.
|
||||
///
|
||||
/// \return The nGraph RoundingType object representing 'ceil_mode' attribute value.
|
||||
ngraph::op::RoundingType get_rounding_type(const Node& node);
|
||||
|
||||
/// \brief Get padding values for the operation described by an ONNX node.
|
||||
/// \details Values are taken from the `pads` attribute.
|
||||
///
|
||||
|
||||
@@ -73,6 +73,7 @@ namespace ngraph
|
||||
Shape m_padding_below;
|
||||
Shape m_padding_above;
|
||||
ngraph::op::PadType m_auto_pad;
|
||||
ngraph::op::RoundingType m_rounding_type;
|
||||
};
|
||||
|
||||
///
|
||||
|
||||
@@ -99,6 +99,12 @@ namespace ngraph
|
||||
return detail::get_attribute_value(node, "dilations", kernel_rank);
|
||||
}
|
||||
|
||||
ngraph::op::RoundingType get_rounding_type(const Node& node)
|
||||
{
|
||||
return static_cast<ngraph::op::RoundingType>(
|
||||
node.get_attribute_value<std::int64_t>("ceil_mode", 0));
|
||||
}
|
||||
|
||||
ngraph::op::PadType get_auto_pad(const Node& node)
|
||||
{
|
||||
// Default value means use explicitly provided padding values.
|
||||
|
||||
@@ -34,6 +34,7 @@ namespace ngraph
|
||||
, m_strides{convpool::get_strides(node)}
|
||||
, m_dilations{convpool::get_dilations(node)}
|
||||
, m_auto_pad{convpool::get_auto_pad(node)}
|
||||
, m_rounding_type{convpool::get_rounding_type(node)}
|
||||
{
|
||||
const auto paddings = convpool::get_pads(node);
|
||||
const CoordinateDiff& padding_above{paddings.second};
|
||||
@@ -52,7 +53,7 @@ namespace ngraph
|
||||
m_padding_above,
|
||||
m_kernel_shape,
|
||||
!count_include_pad,
|
||||
ngraph::op::RoundingType::FLOOR,
|
||||
m_rounding_type,
|
||||
m_auto_pad)};
|
||||
}
|
||||
|
||||
@@ -63,7 +64,7 @@ namespace ngraph
|
||||
m_padding_below,
|
||||
m_padding_above,
|
||||
m_kernel_shape,
|
||||
ngraph::op::RoundingType::FLOOR,
|
||||
m_rounding_type,
|
||||
m_auto_pad)};
|
||||
}
|
||||
|
||||
|
||||
@@ -194,9 +194,7 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu",
|
||||
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
|
||||
(xfail_issue_33616,
|
||||
"OnnxBackendNodeModelTest.test_maxpool_2d_ceil_cpu",
|
||||
"OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu",
|
||||
"OnnxBackendNodeModelTest.test_averagepool_2d_ceil_cpu"),
|
||||
"OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"),
|
||||
(xfail_issue_38086,
|
||||
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_expanded_cpu",
|
||||
|
||||
Reference in New Issue
Block a user