From 328d852f5a2f5475970bbcee1dd13e0ff3a44cbe Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Tue, 30 May 2023 12:29:23 +0200 Subject: [PATCH] [ShapeInference] Reduce* ops shape inference review and update (#17677) * Tests * Add eval_lower/upper support to ReduceMax * Add support for ITensorAccessor in reduce shape infer * Add tests for duplicated axes and output shapes size * Push to output_shapes instead final copy to vector * Remove old shape_infer API * Move axes rank validation to shape_infer * Restore shape_infer API for GPU --- src/core/include/openvino/op/reduce_max.hpp | 2 + .../include/reduce_shape_inference.hpp | 103 +++++++++------ src/core/src/op/reduce_max.cpp | 9 ++ .../util/arithmetic_reductions_keep_dims.cpp | 7 - .../op/util/logical_reduction_keep_dims.cpp | 7 - src/core/src/op/util/reduction_base.cpp | 7 +- src/core/tests/type_prop/reduce_max.cpp | 11 ++ src/core/tests/type_prop/reduce_min.cpp | 11 ++ src/core/tests/type_prop/reduce_ops.hpp | 123 +++++++++++++++++- src/core/tests/type_prop/reduce_prod.cpp | 11 ++ .../utils/shape_inference/shape_inference.cpp | 18 +-- .../reduce_shape_inference_test.cpp | 107 +++++++++++++++ .../unit/shape_inference_test/reduce_test.cpp | 24 ---- 13 files changed, 351 insertions(+), 89 deletions(-) create mode 100644 src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_shape_inference_test.cpp delete mode 100644 src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_test.cpp diff --git a/src/core/include/openvino/op/reduce_max.hpp b/src/core/include/openvino/op/reduce_max.hpp index 499dec82bb9..594450d69d9 100644 --- a/src/core/include/openvino/op/reduce_max.hpp +++ b/src/core/include/openvino/op/reduce_max.hpp @@ -30,6 +30,8 @@ public: bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override; OPENVINO_SUPPRESS_DEPRECATED_END bool has_evaluate() const override; + bool evaluate_lower(TensorVector& outputs) const override; + bool evaluate_upper(TensorVector& outputs) const override; }; } // namespace v1 } // namespace op diff --git a/src/core/shape_inference/include/reduce_shape_inference.hpp b/src/core/shape_inference/include/reduce_shape_inference.hpp index e24b8d94d6e..035659f1e4e 100644 --- a/src/core/shape_inference/include/reduce_shape_inference.hpp +++ b/src/core/shape_inference/include/reduce_shape_inference.hpp @@ -9,61 +9,88 @@ #include "utils.hpp" -template -inline void dynamic_inference(const T& input_shape, T& output_shape, bool keep_dims) { - OPENVINO_THROW("This code should be executed only for PartialShape class"); -} +template +std::vector reduce_shape_infer(const ov::op::util::ReductionBase* op, + bool keep_dims, + const std::vector& input_shapes, + const ov::ITensorAccessor& tensor_accessor = ov::make_tensor_accessor()) { + NODE_VALIDATION_CHECK(op, input_shapes.size() == 2); -template <> -inline void dynamic_inference(const ov::PartialShape& input_shape, - ov::PartialShape& output_shape, - bool keep_dims) { - output_shape = keep_dims ? ov::PartialShape::dynamic(input_shape.rank()) : ov::PartialShape::dynamic(); -} + const auto& data_shape = input_shapes[0]; + const auto& data_rank = data_shape.rank(); + const auto& axes_shape = input_shapes[1]; + const auto& axes_rank = axes_shape.rank(); -template -void reduce_shape_infer(const ov::op::util::ReductionBase* op, - bool keep_dims, - const T& input_shape, - T& output_shape, - const std::map>& constant_data = {}) { - const auto& data_rank = input_shape.rank(); - std::vector axes_val; - bool axes_are_known = get_data_as_int64(1, op, axes_val, constant_data); + std::vector output_shapes; + output_shapes.reserve(1); - if (data_rank.is_static() && axes_are_known) { + NODE_VALIDATION_CHECK(op, + axes_rank.compatible(0) || axes_rank.compatible(1), + "Axes input must be a scalar or 1D input. Got: ", + axes_shape); + + const auto axes_val = ov::op::get_input_const_data_as(op, 1, tensor_accessor); + + if (data_rank.is_static() && axes_val) { OPENVINO_SUPPRESS_DEPRECATED_START - ov::normalize_axes(op, data_rank.get_length(), axes_val); + ov::normalize_axes(op, data_rank.get_length(), *axes_val); OPENVINO_SUPPRESS_DEPRECATED_END if (keep_dims) { - output_shape = input_shape; - for (const auto& axis : axes_val) + output_shapes.push_back(data_shape); + TShape& output_shape = output_shapes[0]; + for (const auto& axis : *axes_val) { output_shape[axis] = 1; - return; + } + } else { + output_shapes.resize(1); + TShape& output_shape = output_shapes[0]; + for (size_t i = 0; i < data_shape.size(); ++i) { + if (std::find(axes_val->begin(), axes_val->end(), i) == axes_val->end()) { + output_shape.push_back(data_shape[i]); + } + } } - for (int64_t i = 0; i < data_rank.get_length(); ++i) - if (find(axes_val.begin(), axes_val.end(), i) == axes_val.end()) - output_shape.push_back(input_shape[i]); } else { - dynamic_inference(input_shape, output_shape, keep_dims); + if (keep_dims) { + output_shapes.push_back(ov::PartialShape::dynamic(data_shape.rank())); + } else { + output_shapes.push_back(ov::PartialShape::dynamic()); + } } + return output_shapes; } -template +// API: TensorAccessor to constant data +template +std::vector shape_infer(const ov::op::util::ArithmeticReductionKeepDims* op, + const std::vector& input_shapes, + const ov::ITensorAccessor& tensor_accessor = ov::make_tensor_accessor()) { + return reduce_shape_infer(op, op->get_keep_dims(), input_shapes, tensor_accessor); +} + +template +std::vector shape_infer(const ov::op::util::LogicalReductionKeepDims* op, + const std::vector& input_shapes, + const ov::ITensorAccessor& tensor_accessor = ov::make_tensor_accessor()) { + return reduce_shape_infer(op, op->get_keep_dims(), input_shapes, tensor_accessor); +} + +// API for compatibility: Constant data map +template void shape_infer(const ov::op::util::ArithmeticReductionKeepDims* op, - const std::vector& input_shapes, - std::vector& output_shapes, + const std::vector& input_shapes, + std::vector& output_shapes, const std::map>& constant_data = {}) { - NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1); - reduce_shape_infer(op, op->get_keep_dims(), input_shapes[0], output_shapes[0], constant_data); + const auto tensor_accessor = ov::make_tensor_accessor(constant_data); + output_shapes = reduce_shape_infer(op, op->get_keep_dims(), input_shapes, tensor_accessor); } -template +template void shape_infer(const ov::op::util::LogicalReductionKeepDims* op, - const std::vector& input_shapes, - std::vector& output_shapes, + const std::vector& input_shapes, + std::vector& output_shapes, const std::map>& constant_data = {}) { - NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1); - reduce_shape_infer(op, op->get_keep_dims(), input_shapes[0], output_shapes[0], constant_data); + const auto tensor_accessor = ov::make_tensor_accessor(constant_data); + output_shapes = reduce_shape_infer(op, op->get_keep_dims(), input_shapes, tensor_accessor); } diff --git a/src/core/src/op/reduce_max.cpp b/src/core/src/op/reduce_max.cpp index d6faa03bac7..8504a7810de 100644 --- a/src/core/src/op/reduce_max.cpp +++ b/src/core/src/op/reduce_max.cpp @@ -4,6 +4,7 @@ #include +#include "bound_evaluate.hpp" #include "itt.hpp" #include "ngraph/graph_util.hpp" #include "ngraph/op/max.hpp" @@ -85,3 +86,11 @@ bool op::v1::ReduceMax::has_evaluate() const { } return false; } + +bool op::v1::ReduceMax::evaluate_lower(ov::TensorVector& output_values) const { + return input_value(1).get_tensor().has_and_set_bound() && default_lower_bound_evaluator(this, output_values); +} + +bool op::v1::ReduceMax::evaluate_upper(ov::TensorVector& output_values) const { + return input_value(1).get_tensor().has_and_set_bound() && default_upper_bound_evaluator(this, output_values); +} diff --git a/src/core/src/op/util/arithmetic_reductions_keep_dims.cpp b/src/core/src/op/util/arithmetic_reductions_keep_dims.cpp index 7380530b4e9..398e7d890da 100644 --- a/src/core/src/op/util/arithmetic_reductions_keep_dims.cpp +++ b/src/core/src/op/util/arithmetic_reductions_keep_dims.cpp @@ -28,7 +28,6 @@ void ov::op::util::ArithmeticReductionKeepDims::validate_and_infer_types() { OV_OP_SCOPE(v0_util_ArithmeticReductionKeepDims_validate_and_infer_types); const element::Type& data_et = get_input_element_type(0); - const PartialShape& axes_shape = get_input_partial_shape(1); const element::Type& axes_et = get_input_element_type(1); NODE_VALIDATION_CHECK(this, @@ -41,12 +40,6 @@ void ov::op::util::ArithmeticReductionKeepDims::validate_and_infer_types() { "Element type of axes input must be integer. Got: ", axes_et); - const Rank axes_rank = axes_shape.rank(); - NODE_VALIDATION_CHECK(this, - axes_rank.compatible(0) || axes_rank.compatible(1), - "Axes input must be a scalar or 1D input. Got: ", - axes_shape); - PartialShape result_shape = infer_reduction_output_shape(m_keep_dims); set_input_is_relevant_to_shape(1); set_output_type(0, data_et, result_shape); diff --git a/src/core/src/op/util/logical_reduction_keep_dims.cpp b/src/core/src/op/util/logical_reduction_keep_dims.cpp index 6d662426466..9ae830f93c0 100644 --- a/src/core/src/op/util/logical_reduction_keep_dims.cpp +++ b/src/core/src/op/util/logical_reduction_keep_dims.cpp @@ -27,7 +27,6 @@ void ov::op::util::LogicalReductionKeepDims::validate_and_infer_types() { OV_OP_SCOPE(v0_util_LogicalReductionKeepDims_validate_and_infer_types); const element::Type& data_et = get_input_element_type(0); - const PartialShape& axes_shape = get_input_partial_shape(1); const element::Type& axes_et = get_input_element_type(1); NODE_VALIDATION_CHECK(this, data_et.compatible(element::boolean), "Element type of data input must be boolean."); @@ -37,12 +36,6 @@ void ov::op::util::LogicalReductionKeepDims::validate_and_infer_types() { "Element type of axes input must be integer. Got: ", axes_et); - const Rank axes_rank = axes_shape.rank(); - NODE_VALIDATION_CHECK(this, - axes_rank.compatible(0) || axes_rank.compatible(1), - "Axes input must be a scalar or 1D input. Got: ", - axes_shape); - PartialShape result_shape = infer_reduction_output_shape(m_keep_dims); set_input_is_relevant_to_shape(1); set_output_type(0, data_et, result_shape); diff --git a/src/core/src/op/util/reduction_base.cpp b/src/core/src/op/util/reduction_base.cpp index 72e3a4c2e8a..f00b01371af 100644 --- a/src/core/src/op/util/reduction_base.cpp +++ b/src/core/src/op/util/reduction_base.cpp @@ -15,9 +15,10 @@ ov::op::util::ReductionBase::ReductionBase(const Output& arg, const Output : Op({arg, reduction_axes}) {} ov::PartialShape ov::op::util::ReductionBase::infer_reduction_output_shape(const bool keep_dims) { - ov::PartialShape output_shape; - reduce_shape_infer(this, keep_dims, get_input_partial_shape(0), output_shape); - return output_shape; + return reduce_shape_infer(this, + keep_dims, + std::vector{get_input_partial_shape(0), get_input_partial_shape(1)}) + .front(); } bool ov::op::util::ReductionBase::reduction_axes_constant() const { diff --git a/src/core/tests/type_prop/reduce_max.cpp b/src/core/tests/type_prop/reduce_max.cpp index 9174b729a99..3bde75d3f45 100644 --- a/src/core/tests/type_prop/reduce_max.cpp +++ b/src/core/tests/type_prop/reduce_max.cpp @@ -7,3 +7,14 @@ using Type = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_reduce_max, ReduceTest, Type); INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_reduce_max_et, ReduceArithmeticTest, Type); + +TEST(type_prop, reduce_max_value_propagation) { + const auto param = std::make_shared(element::f32, PartialShape{{1, 8}, {2, 3}, 6}); + const auto shape_of = std::make_shared(param); + const auto reduce_prod = + std::make_shared(shape_of, op::Constant::create(element::i64, {1}, {0}), true); + const auto reshape = std::make_shared(param, reduce_prod, false); + + EXPECT_EQ(reshape->get_element_type(), ov::element::f32); + EXPECT_EQ(reshape->get_output_partial_shape(0), (PartialShape{{6, 8}})); +} diff --git a/src/core/tests/type_prop/reduce_min.cpp b/src/core/tests/type_prop/reduce_min.cpp index 1198fb48161..eecd4a27b11 100644 --- a/src/core/tests/type_prop/reduce_min.cpp +++ b/src/core/tests/type_prop/reduce_min.cpp @@ -7,3 +7,14 @@ using Type = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_reduce_min, ReduceTest, Type); INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_reduce_min_et, ReduceArithmeticTest, Type); + +TEST(type_prop, reduce_min_value_propagation) { + const auto param = std::make_shared(element::f32, PartialShape{{1, 8}, {2, 3}, 6}); + const auto shape_of = std::make_shared(param); + const auto reduce_prod = + std::make_shared(shape_of, op::Constant::create(element::i64, {1}, {0}), true); + const auto reshape = std::make_shared(param, reduce_prod, false); + + EXPECT_EQ(reshape->get_element_type(), ov::element::f32); + EXPECT_EQ(reshape->get_output_partial_shape(0), (PartialShape{{1, 3}})); +} diff --git a/src/core/tests/type_prop/reduce_ops.hpp b/src/core/tests/type_prop/reduce_ops.hpp index 28ad1607032..2ed6d51ef6e 100644 --- a/src/core/tests/type_prop/reduce_ops.hpp +++ b/src/core/tests/type_prop/reduce_ops.hpp @@ -2,7 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "gtest/gtest.h" +#include "common_test_utils/test_assertions.hpp" +#include "gmock/gmock.h" #include "ngraph/ngraph.hpp" #include "util/type_prop.hpp" @@ -38,6 +39,31 @@ class ReduceTest : public testing::Test {}; TYPED_TEST_SUITE_P(ReduceTest); +TYPED_TEST_P(ReduceTest, reduce_default_ctor) { + PartialShape data_ps{3, 4, 5}; + element::Type data_et = element::dynamic; + + Shape axes_ps{2}; + element::Type axes_et = element::i64; + std::vector axes{1, 2}; + + bool keep_dims = true; + + const auto data = make_shared(data_et, data_ps); + const auto in_axes = make_shared(axes_et, axes_ps); + + auto op = std::make_shared(); + op->set_arguments(OutputVector{data, in_axes}); + op->set_keep_dims(keep_dims); + op->validate_and_infer_types(); + + EXPECT_EQ(op->get_input_size(), 2); + EXPECT_EQ(op->get_output_size(), 1); + + EXPECT_EQ(op->get_keep_dims(), keep_dims); + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic(3)); +} + TYPED_TEST_P(ReduceTest, reduce_basic_shape_infer) { PartialShape data_ps{3, 4, 5}; element::Type data_et = element::dynamic; @@ -72,6 +98,40 @@ TYPED_TEST_P(ReduceTest, reduce_basic_shape_infer_keep_dims) { ASSERT_EQ(reduce_op->get_output_partial_shape(0), out_ps); } +TYPED_TEST_P(ReduceTest, reduce_basic_shape_infer_duplicated_axes) { + PartialShape data_ps{3, 4, 5}; + element::Type data_et = element::dynamic; + + Shape axes_ps{2}; + element::Type axes_et = element::i64; + std::vector axes{1, 1}; + + bool keep_dims = false; + + PartialShape out_ps{3, 5}; + + const ReduceParams params{data_ps, data_et, axes_ps, axes, axes_et, keep_dims}; + auto reduce_op = makeReduceOp(params); + EXPECT_EQ(reduce_op->get_output_partial_shape(0), out_ps); +} + +TYPED_TEST_P(ReduceTest, reduce_basic_shape_infer_keep_dims_duplicated_axes) { + PartialShape data_ps{3, 4, 5}; + element::Type data_et = element::dynamic; + + Shape axes_ps{2}; + element::Type axes_et = element::i64; + std::vector axes{1, 1}; + + bool keep_dims = true; + + PartialShape out_ps{3, 1, 5}; + + const ReduceParams params{data_ps, data_et, axes_ps, axes, axes_et, keep_dims}; + auto reduce_op = makeReduceOp(params); + EXPECT_EQ(reduce_op->get_output_partial_shape(0), out_ps); +} + TYPED_TEST_P(ReduceTest, reduce_basic_shape_infer_scalar_axis) { PartialShape data_ps{3, 4, 5}; element::Type data_et = element::dynamic; @@ -192,6 +252,62 @@ TYPED_TEST_P(ReduceTest, reduce_dynamic_shape_data) { ASSERT_EQ(reduce_op->get_output_partial_shape(0), out_ps); } +TYPED_TEST_P(ReduceTest, dynamic_interval_labeled_shape_data_axes_const) { + using namespace testing; + + PartialShape data_ps{-1, -1, 1, 1, 6, 16, {-1, 8}, {-1, 18}, {4, -1}, {14, -1}, {3, 9}, {13, 19}}; + element::Type data_et = element::dynamic; + + set_shape_labels(data_ps, 10); + + Shape axes_ps{6}; + element::Type axes_et = element::i64; + std::vector axes{1, 3, 5, 7, 9, 11}; + + bool keep_dims = false; + + PartialShape out_ps{-1, 1, 6, {-1, 8}, {4, -1}, {3, 9}}; + + const ReduceParams params{data_ps, data_et, axes_ps, axes, axes_et, keep_dims}; + auto reduce_op = makeReduceOp(params); + EXPECT_EQ(reduce_op->get_output_partial_shape(0), out_ps); + EXPECT_THAT(get_shape_labels(reduce_op->get_output_partial_shape(0)), ElementsAre(10, 12, 14, 16, 18, 20)); +} + +TYPED_TEST_P(ReduceTest, dynamic_interval_labeled_shape_data_axes_const_keep_dims) { + using namespace testing; + + PartialShape data_ps{-1, -1, 1, 1, 6, 16, {-1, 8}, {-1, 18}, {4, -1}, {14, -1}, {3, 9}, {13, 19}}; + element::Type data_et = element::dynamic; + + set_shape_labels(data_ps, 10); + + Shape axes_ps{6}; + element::Type axes_et = element::i64; + std::vector axes{1, 3, 5, 7, 9, 11}; + + bool keep_dims = true; + + PartialShape out_ps{-1, 1, 1, 1, 6, 1, {-1, 8}, 1, {4, -1}, 1, {3, 9}, 1}; + + const ReduceParams params{data_ps, data_et, axes_ps, axes, axes_et, keep_dims}; + auto reduce_op = makeReduceOp(params); + EXPECT_EQ(reduce_op->get_output_partial_shape(0), out_ps); + EXPECT_THAT(get_shape_labels(reduce_op->get_output_partial_shape(0)), + ElementsAre(10, + ov::no_label, + 12, + ov::no_label, + 14, + ov::no_label, + 16, + ov::no_label, + 18, + ov::no_label, + 20, + ov::no_label)); +} + TYPED_TEST_P(ReduceTest, reduce_invalid_axis_out_of_range) { PartialShape data_ps{1, 2, 3}; element::Type data_et = element::dynamic; @@ -256,8 +372,11 @@ TYPED_TEST_P(ReduceTest, reduce_invalid_axes_et) { } REGISTER_TYPED_TEST_SUITE_P(ReduceTest, + reduce_default_ctor, reduce_basic_shape_infer, reduce_basic_shape_infer_keep_dims, + reduce_basic_shape_infer_duplicated_axes, + reduce_basic_shape_infer_keep_dims_duplicated_axes, reduce_basic_shape_infer_scalar_axis, reduce_basic_shape_infer_axes_as_param, reduce_dynamic_shape_data, @@ -265,6 +384,8 @@ REGISTER_TYPED_TEST_SUITE_P(ReduceTest, reduce_dynamic_shape_reduced_axes_static_keep_dims, reduce_dynamic_shape_reduced_axes_not_static, reduce_dynamic_shape_reduced_axes_not_static_keep_dims, + dynamic_interval_labeled_shape_data_axes_const_keep_dims, + dynamic_interval_labeled_shape_data_axes_const, reduce_invalid_axis_out_of_range, reduce_invalid_axes_shape, reduce_invalid_axes_et); diff --git a/src/core/tests/type_prop/reduce_prod.cpp b/src/core/tests/type_prop/reduce_prod.cpp index 67769035cb3..f66e8c44d31 100644 --- a/src/core/tests/type_prop/reduce_prod.cpp +++ b/src/core/tests/type_prop/reduce_prod.cpp @@ -7,3 +7,14 @@ using Type = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_reduce_prod, ReduceTest, Type); INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_reduce_prod_et, ReduceArithmeticTest, Type); + +TEST(type_prop, reduce_prod_value_propagation) { + const auto param = std::make_shared(element::f32, PartialShape{{1, 8}, {2, 3}, 6}); + const auto shape_of = std::make_shared(param); + const auto reduce_prod = + std::make_shared(shape_of, op::Constant::create(element::i64, {1}, {0}), true); + const auto reshape = std::make_shared(param, reduce_prod, false); + + EXPECT_EQ(reshape->get_element_type(), ov::element::f32); + EXPECT_EQ(reshape->get_output_partial_shape(0), (PartialShape{{12, 144}})); +} diff --git a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp index 476de5d8c42..77a6155084c 100644 --- a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp @@ -569,15 +569,6 @@ const IShapeInferCommonFactory::TRegistry IShapeInferCommonFactory::registry{ _OV_OP_SHAPE_INFER_REG(Unsqueeze, entryIOC), _OV_OP_SHAPE_INFER_REG(VariadicSplit, entryIOC), _OV_OP_SHAPE_INFER_VA_REG(Gather, entryIOC, ov::op::util::GatherBase), - _OV_OP_SHAPE_INFER_VA_REG(ReduceL1, entryIOC, op::util::ArithmeticReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceL2, entryIOC, op::util::ArithmeticReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceLogicalAnd, entryIOC, op::util::LogicalReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceLogicalOr, entryIOC, op::util::LogicalReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceMax, entryIOC, op::util::ArithmeticReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceMean, entryIOC, op::util::ArithmeticReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceMin, entryIOC, op::util::ArithmeticReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceProd, entryIOC, op::util::ArithmeticReductionKeepDims), - _OV_OP_SHAPE_INFER_VA_REG(ReduceSum, entryIOC, op::util::ArithmeticReductionKeepDims), // opset7 _OV_OP_SHAPE_INFER_VA_REG(opset7::Gather, entryIOC, ov::op::util::GatherBase), // opset5 @@ -615,6 +606,15 @@ const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{ // Default opset _OV_OP_SHAPE_INFER_MASK_REG(ExperimentalDetectronROIFeatureExtractor, ShapeInferTA, util::bit::mask()), _OV_OP_SHAPE_INFER_MASK_REG(Proposal, ShapeInferTA, util::bit::mask()), + _OV_OP_SHAPE_INFER_VA_REG(ReduceL1, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceL2, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceLogicalAnd, ShapeInferTA, op::util::LogicalReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceLogicalOr, ShapeInferTA, op::util::LogicalReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceMax, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceMean, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceMin, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceProd, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_VA_REG(ReduceSum, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)), _OV_OP_SHAPE_INFER_MASK_REG(Tile, ShapeInferTA, util::bit::mask(1)), // Operators shape inferences for specific opset version should be specified below // opset1 diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_shape_inference_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_shape_inference_test.cpp new file mode 100644 index 00000000000..94f0c852b73 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_shape_inference_test.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "utils.hpp" +#include "reduce_shape_inference.hpp" + +using namespace ov; +using namespace ov::intel_cpu; +using namespace testing; + +template +class ReduceStaticShapeInferenceTest : public OpStaticShapeInferenceTest { +protected: + void SetUp() override { + this->output_shapes = ShapeVector(1); + } +}; + +TYPED_TEST_SUITE_P(ReduceStaticShapeInferenceTest); + +TYPED_TEST_P(ReduceStaticShapeInferenceTest, default_ctor) { + this->op = this->make_op(); + this->op->set_keep_dims(true); + this->input_shapes = ShapeVector{{1, 6, 7, 8, 4}, {3}}; + + int32_t axes_val[] = {0, 1, 3}; + const std::map>& constant_data = { + {1, std::make_shared(element::i32, Shape{3}, axes_val)}}; + shape_inference(this->op.get(), this->input_shapes, this->output_shapes, constant_data); + + EXPECT_EQ(this->output_shapes.size(), 1); + EXPECT_EQ(this->output_shapes.front(), StaticShape({1, 1, 7, 1, 4})); +} + +TYPED_TEST_P(ReduceStaticShapeInferenceTest, axes_constant) { + const auto data = std::make_shared(element::dynamic, PartialShape{-1, -1, -1, -1}); + const auto axes = std::make_shared(element::i32, Shape{2}, std::vector{1, 3}); + + this->op = this->make_op(data, axes, false); + this->input_shapes = {StaticShape{3, 6, 5, 8}, StaticShape{2}}; + + shape_inference(this->op.get(), this->input_shapes, this->output_shapes); + + EXPECT_EQ(this->output_shapes.size(), 1); + EXPECT_EQ(this->output_shapes.front(), StaticShape({3, 5})); +} + +TYPED_TEST_P(ReduceStaticShapeInferenceTest, axes_param) { + const auto data = std::make_shared(element::dynamic, PartialShape{-1, -1, -1, -1}); + const auto axes = std::make_shared(element::i32, Shape{2}); + + this->op = this->make_op(data, axes, false); + this->input_shapes = {StaticShape{3, 6, 5, 8}, StaticShape{2}}; + + int32_t axes_val[] = {1, 3}; + const std::map>& constant_data = { + {1, std::make_shared(element::i32, Shape{2}, axes_val)}}; + shape_inference(this->op.get(), this->input_shapes, this->output_shapes, constant_data); + + EXPECT_EQ(this->output_shapes.size(), 1); + EXPECT_EQ(this->output_shapes.front(), StaticShape({3, 5})); +} + +TYPED_TEST_P(ReduceStaticShapeInferenceTest, axes_constant_keep_dims) { + const auto data = std::make_shared(element::dynamic, PartialShape{-1, -1, -1, -1}); + const auto axes = std::make_shared(element::i32, Shape{2}, std::vector{1, 3}); + + this->op = this->make_op(data, axes, true); + this->input_shapes = {StaticShape{3, 6, 5, 8}, StaticShape{2}}; + + shape_inference(this->op.get(), this->input_shapes, this->output_shapes); + + EXPECT_EQ(this->output_shapes.size(), 1); + EXPECT_EQ(this->output_shapes.front(), StaticShape({3, 1, 5, 1})); +} + +TYPED_TEST_P(ReduceStaticShapeInferenceTest, axes_param_keep_dims) { + const auto data = std::make_shared(element::dynamic, PartialShape{-1, -1, -1, -1}); + const auto axes = std::make_shared(element::i32, Shape{2}); + + this->op = this->make_op(data, axes, true); + this->input_shapes = {StaticShape{3, 6, 5, 8}, StaticShape{2}}; + + int32_t axes_val[] = {1, 3}; + const std::map>& constant_data = { + {1, std::make_shared(element::i32, Shape{2}, axes_val)}}; + shape_inference(this->op.get(), this->input_shapes, this->output_shapes, constant_data); + + EXPECT_EQ(this->output_shapes.size(), 1); + EXPECT_EQ(this->output_shapes.front(), StaticShape({3, 1, 5, 1})); +} + +REGISTER_TYPED_TEST_SUITE_P(ReduceStaticShapeInferenceTest, + default_ctor, + axes_constant, + axes_param, + axes_param_keep_dims, + axes_constant_keep_dims); + +using ReduceOpTypes = + Types; +INSTANTIATE_TYPED_TEST_SUITE_P(shape_inference, ReduceStaticShapeInferenceTest, ReduceOpTypes); diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_test.cpp deleted file mode 100644 index 9dd98765257..00000000000 --- a/src/plugins/intel_cpu/tests/unit/shape_inference_test/reduce_test.cpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include "utils.hpp" - -using namespace ov; -using namespace ov::intel_cpu; - -TEST(StaticShapeInferenceTest, ReduceTest) { - auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); - auto axes = std::make_shared(element::i32, Shape{2}, std::vector{1, 3}); - - auto reduce = - std::make_shared(data, axes, true); - - std::vector static_input_shapes = {StaticShape{3, 6, 5, 5}, StaticShape{2}}, - static_output_shapes = {StaticShape{}}; - shape_inference(reduce.get(), static_input_shapes, static_output_shapes); - - ASSERT_EQ(static_output_shapes[0], StaticShape({3, 1, 5, 1})); -}