[ONNX Importer] 🕸️ Make global pools more dynamic (#6122)

This commit is contained in:
Tomasz Socha
2021-06-11 11:27:58 +02:00
committed by GitHub
parent 8830dcabf3
commit 167a5519b1
2 changed files with 32 additions and 36 deletions

View File

@@ -19,19 +19,6 @@ namespace ngraph
{
OutputVector global_average_pool(const Node& node)
{
auto data = node.get_ng_inputs()[0];
auto data_rank = data.get_partial_shape().rank();
NGRAPH_CHECK(data_rank.is_static(),
"The input data tensor's rank has to be known (static)");
auto data_rank_value = data_rank.get_length();
NGRAPH_CHECK(data_rank_value > 2,
"The input data tensor's rank has to be greater than 2."
"Provided data rank is: ",
data_rank_value);
// Generate axes for reduce operation which contain all spatial dims indexes.
// Examples:
// Input shape: [N, C, H, W]
@@ -41,11 +28,22 @@ namespace ngraph
// Input shape: [N, C, H, W, D]
// Input spatial dimensions are H, W and D
// Expected spatial dims indexes: [2, 3, 4]
size_t data_spatial_rank = data_rank_value - 2;
auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
auto reduce_axes = default_opset::Constant::create(
element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
auto data = node.get_ng_inputs()[0];
const auto zero_node =
default_opset::Constant::create(element::i64, Shape{}, {0});
const auto one_node =
default_opset::Constant::create(element::i64, Shape{}, {1});
const auto two_node =
default_opset::Constant::create(element::i64, Shape{}, {2});
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
const auto data_rank = std::make_shared<default_opset::ShapeOf>(data_shape);
const auto data_rank_as_scalar =
std::make_shared<default_opset::Squeeze>(data_rank);
const auto reduce_axes = std::make_shared<default_opset::Range>(
two_node, data_rank_as_scalar, one_node, element::i64);
return {std::make_shared<default_opset::ReduceMean>(data, reduce_axes, true)};
}

View File

@@ -19,19 +19,6 @@ namespace ngraph
{
OutputVector global_max_pool(const Node& node)
{
auto data = node.get_ng_inputs()[0];
auto data_rank = data.get_partial_shape().rank();
NGRAPH_CHECK(data_rank.is_static(),
"The input data tensor's rank has to be known (static)");
auto data_rank_value = data_rank.get_length();
NGRAPH_CHECK(data_rank_value > 2,
"The input data tensor's rank has to be greater than 2."
"Provided data rank is: ",
data_rank_value);
// Generate axes for reduce operation which contain all spatial dims indexes.
// Examples:
// Input shape: [N, C, H, W]
@@ -41,11 +28,22 @@ namespace ngraph
// Input shape: [N, C, H, W, D]
// Input spatial dimensions are H, W and D
// Expected spatial dims indexes: [2, 3, 4]
size_t data_spatial_rank = data_rank_value - 2;
auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
auto reduce_axes = default_opset::Constant::create(
element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
auto data = node.get_ng_inputs()[0];
const auto zero_node =
default_opset::Constant::create(element::i64, Shape{}, {0});
const auto one_node =
default_opset::Constant::create(element::i64, Shape{}, {1});
const auto two_node =
default_opset::Constant::create(element::i64, Shape{}, {2});
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
const auto data_rank = std::make_shared<default_opset::ShapeOf>(data_shape);
const auto data_rank_as_scalar =
std::make_shared<default_opset::Squeeze>(data_rank);
const auto reduce_axes = std::make_shared<default_opset::Range>(
two_node, data_rank_as_scalar, one_node, element::i64);
return {std::make_shared<default_opset::ReduceMax>(data, reduce_axes, true)};
}