[ONNX Importer] 🕸️ Make global pools more dynamic (#6122)
This commit is contained in:
@@ -19,19 +19,6 @@ namespace ngraph
|
||||
{
|
||||
OutputVector global_average_pool(const Node& node)
|
||||
{
|
||||
auto data = node.get_ng_inputs()[0];
|
||||
auto data_rank = data.get_partial_shape().rank();
|
||||
|
||||
NGRAPH_CHECK(data_rank.is_static(),
|
||||
"The input data tensor's rank has to be known (static)");
|
||||
|
||||
auto data_rank_value = data_rank.get_length();
|
||||
|
||||
NGRAPH_CHECK(data_rank_value > 2,
|
||||
"The input data tensor's rank has to be greater than 2."
|
||||
"Provided data rank is: ",
|
||||
data_rank_value);
|
||||
|
||||
// Generate axes for reduce operation which contain all spatial dims indexes.
|
||||
// Examples:
|
||||
// Input shape: [N, C, H, W]
|
||||
@@ -41,11 +28,22 @@ namespace ngraph
|
||||
// Input shape: [N, C, H, W, D]
|
||||
// Input spatial dimensions are H, W and D
|
||||
// Expected spatial dims indexes: [2, 3, 4]
|
||||
size_t data_spatial_rank = data_rank_value - 2;
|
||||
auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
|
||||
std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
|
||||
auto reduce_axes = default_opset::Constant::create(
|
||||
element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
|
||||
auto data = node.get_ng_inputs()[0];
|
||||
|
||||
const auto zero_node =
|
||||
default_opset::Constant::create(element::i64, Shape{}, {0});
|
||||
const auto one_node =
|
||||
default_opset::Constant::create(element::i64, Shape{}, {1});
|
||||
const auto two_node =
|
||||
default_opset::Constant::create(element::i64, Shape{}, {2});
|
||||
|
||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
const auto data_rank = std::make_shared<default_opset::ShapeOf>(data_shape);
|
||||
const auto data_rank_as_scalar =
|
||||
std::make_shared<default_opset::Squeeze>(data_rank);
|
||||
|
||||
const auto reduce_axes = std::make_shared<default_opset::Range>(
|
||||
two_node, data_rank_as_scalar, one_node, element::i64);
|
||||
|
||||
return {std::make_shared<default_opset::ReduceMean>(data, reduce_axes, true)};
|
||||
}
|
||||
|
||||
@@ -19,19 +19,6 @@ namespace ngraph
|
||||
{
|
||||
OutputVector global_max_pool(const Node& node)
|
||||
{
|
||||
auto data = node.get_ng_inputs()[0];
|
||||
auto data_rank = data.get_partial_shape().rank();
|
||||
|
||||
NGRAPH_CHECK(data_rank.is_static(),
|
||||
"The input data tensor's rank has to be known (static)");
|
||||
|
||||
auto data_rank_value = data_rank.get_length();
|
||||
|
||||
NGRAPH_CHECK(data_rank_value > 2,
|
||||
"The input data tensor's rank has to be greater than 2."
|
||||
"Provided data rank is: ",
|
||||
data_rank_value);
|
||||
|
||||
// Generate axes for reduce operation which contain all spatial dims indexes.
|
||||
// Examples:
|
||||
// Input shape: [N, C, H, W]
|
||||
@@ -41,11 +28,22 @@ namespace ngraph
|
||||
// Input shape: [N, C, H, W, D]
|
||||
// Input spatial dimensions are H, W and D
|
||||
// Expected spatial dims indexes: [2, 3, 4]
|
||||
size_t data_spatial_rank = data_rank_value - 2;
|
||||
auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
|
||||
std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
|
||||
auto reduce_axes = default_opset::Constant::create(
|
||||
element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
|
||||
auto data = node.get_ng_inputs()[0];
|
||||
|
||||
const auto zero_node =
|
||||
default_opset::Constant::create(element::i64, Shape{}, {0});
|
||||
const auto one_node =
|
||||
default_opset::Constant::create(element::i64, Shape{}, {1});
|
||||
const auto two_node =
|
||||
default_opset::Constant::create(element::i64, Shape{}, {2});
|
||||
|
||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
const auto data_rank = std::make_shared<default_opset::ShapeOf>(data_shape);
|
||||
const auto data_rank_as_scalar =
|
||||
std::make_shared<default_opset::Squeeze>(data_rank);
|
||||
|
||||
const auto reduce_axes = std::make_shared<default_opset::Range>(
|
||||
two_node, data_rank_as_scalar, one_node, element::i64);
|
||||
|
||||
return {std::make_shared<default_opset::ReduceMax>(data, reduce_axes, true)};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user