diff --git a/src/core/dev_api/dimension_tracker.hpp b/src/core/dev_api/dimension_tracker.hpp index 5023e584697..398abe995aa 100644 --- a/src/core/dev_api/dimension_tracker.hpp +++ b/src/core/dev_api/dimension_tracker.hpp @@ -12,6 +12,9 @@ namespace ov { +/// \brief Special label value indicating no label is set. +constexpr size_t no_label = 0; + /// \brief Friend class of Dimension to set, get and track dimensions and their equivalence class DimensionTracker { public: @@ -22,7 +25,7 @@ public: }; static void set_label(ov::Dimension& d, size_t label) { - OPENVINO_ASSERT(label != 0, "Can not set zero as label for dimension -- it is reserved for no label"); + OPENVINO_ASSERT(label != no_label, "Can not set zero as label for dimension -- it is reserved for no label"); d.m_label = label; } @@ -47,7 +50,7 @@ public: } static void reset_tracking_info(ov::Dimension& d) { - d.m_label = 0; + d.m_label = no_label; d.m_table_of_equivalence = nullptr; } diff --git a/src/core/include/openvino/core/validation_util.hpp b/src/core/include/openvino/core/validation_util.hpp index f869e4a2c96..50522b8eff9 100644 --- a/src/core/include/openvino/core/validation_util.hpp +++ b/src/core/include/openvino/core/validation_util.hpp @@ -140,7 +140,6 @@ OPENVINO_API bool default_label_evaluator(const Node* node, TensorLabelVector& o /// /// \param axes_order Vector where default order will be generated. /// \param length Sequence length of axes order. -/// OPENVINO_API void generate_transpose_default_order(std::vector& axes_order, const size_t length); /// \brief Check if vector of axes order has got valid values. @@ -151,6 +150,11 @@ OPENVINO_API void generate_transpose_default_order(std::vector& axes_or /// \param size Input for transpose rank size. /// /// \return true if axes order is valid otherwise false. 
-/// OPENVINO_API bool is_valid_axes_order(const std::vector& axes_order, const size_t size); + +/// \brief Checks whether the label tensor contains no labels. +/// +/// \param labels Label tensor to check. +/// \return True if there are no labels, otherwise false. +OPENVINO_API bool has_no_labels(const TensorLabel& labels); } // namespace ov diff --git a/src/core/shape_inference/include/compare.hpp b/src/core/shape_inference/include/compare.hpp index b72ec06c0e2..e11f050f0f9 100644 --- a/src/core/shape_inference/include/compare.hpp +++ b/src/core/shape_inference/include/compare.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -48,5 +48,22 @@ public: return (_lower_bound <= value) && (value <= _upper_bound); } }; + +/** + * \brief Compare if value is equal to expected. + * + * \tparam T Value type to compare. + */ +template +class Equal { + T _exp_value; + +public: + constexpr Equal(const T& exp_value) : _exp_value{exp_value} {} + + constexpr bool operator()(const T& value) const { + return _exp_value == value; + } +}; } // namespace cmp } // namespace ov diff --git a/src/core/shape_inference/include/concat_shape_inference.hpp b/src/core/shape_inference/include/concat_shape_inference.hpp new file mode 100644 index 00000000000..d571b9417c4 --- /dev/null +++ b/src/core/shape_inference/include/concat_shape_inference.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/validation_util.hpp" +#include "openvino/op/concat.hpp" +#include "utils.hpp" + +namespace ov { +namespace op { +namespace v0 { + +template +void shape_infer(const Concat* op, const std::vector& input_shapes, std::vector& output_shapes) { + using DimType = typename std::iterator_traits::value_type; + + const auto concat_axis = op->get_concatenation_axis(); + const auto empty_dim = DimType{}; + + auto concat_dim = 
DimType{0}; + auto& output_shape = output_shapes.front(); + + if (std::is_same::value) { + output_shape = PartialShape::dynamic(); + } else { + output_shape = input_shapes.front(); + output_shape[concat_axis] = empty_dim; + } + + for (auto input : input_shapes) { + if (input.rank().is_static()) { + concat_dim += input[concat_axis]; + input[concat_axis] = empty_dim; + + NODE_VALIDATION_CHECK(op, + T::merge_into(output_shape, input), + "Argument shapes are inconsistent; they must have the same rank, and must " + "have ", + "equal dimension everywhere except on the concatenation axis (axis ", + concat_axis, + ")."); + } else { + concat_dim += empty_dim; + } + } + + if (output_shape.rank().is_static()) { + output_shape[concat_axis] = concat_dim; + } +} +} // namespace v0 +} // namespace op +} // namespace ov diff --git a/src/core/src/op/concat.cpp b/src/core/src/op/concat.cpp index d6504ff42d8..355f15d4981 100644 --- a/src/core/src/op/concat.cpp +++ b/src/core/src/op/concat.cpp @@ -5,11 +5,13 @@ #include "ngraph/op/concat.hpp" #include -#include +#include "concat_shape_inference.hpp" +#include "dimension_tracker.hpp" #include "itt.hpp" #include "ngraph/attribute_visitor.hpp" #include "ngraph/runtime/reference/concat.hpp" +#include "ngraph/validation_util.hpp" using namespace std; using namespace ngraph; @@ -32,63 +34,48 @@ void op::Concat::validate_and_infer_types() { OV_OP_SCOPE(v0_Concat_validate_and_infer_types); NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required."); - ov::PartialShape inputs_shape_scheme{ov::PartialShape::dynamic()}; element::Type inputs_et{element::dynamic}; - Dimension concatenation_axis_output_dim{0}; + auto input_shapes = std::vector(); - for (uint64_t i = 0; i < get_input_size(); i++) { + for (size_t i = 0; i < get_input_size(); ++i) { NODE_VALIDATION_CHECK(this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)), "Argument element types are inconsistent."); - ov::PartialShape this_input_shape 
= get_input_partial_shape(i); - Dimension this_input_rank = this_input_shape.rank(); - if (this_input_rank.is_static()) { - if (get_concatenation_axis() < 0) { - set_concatenation_axis(get_axis() < 0 ? get_axis() + this_input_rank.get_length() : get_axis()); - } - auto concat_axis = get_concatenation_axis(); - NODE_VALIDATION_CHECK(this, - concat_axis < this_input_rank.get_length() && concat_axis >= 0, - "Concatenation axis (", - concat_axis, - ") is out of bounds [", - -this_input_rank.get_length(), - ", ", - this_input_rank.get_length() - 1, - "] for ", - "argument ", - i, - ", which has shape ", - this_input_shape, - "."); + const auto& input_shape = get_input_partial_shape(i); + const auto& input_rank = input_shape.rank(); - concatenation_axis_output_dim += this_input_shape[concat_axis]; - this_input_shape[concat_axis] = Dimension::dynamic(); - - NODE_VALIDATION_CHECK(this, - ov::PartialShape::merge_into(inputs_shape_scheme, this_input_shape), - "Argument shapes are inconsistent; they must have the same rank, and must " - "have ", - "equal dimension everywhere except on the concatenation axis (axis ", - concat_axis, - ")."); - } else { - concatenation_axis_output_dim += Dimension::dynamic(); + if (input_rank.is_static() && (get_concatenation_axis() < 0)) { + set_concatenation_axis(get_axis() < 0 ? 
get_axis() + input_rank.get_length() : get_axis()); } - } - ov::PartialShape concatenated_shape = inputs_shape_scheme; - if (concatenated_shape.rank().is_static()) { - concatenated_shape[get_concatenation_axis()] = concatenation_axis_output_dim; - set_output_type(0, inputs_et, concatenated_shape); - } else { - set_output_type(0, inputs_et, ov::PartialShape::dynamic(concatenation_axis_output_dim)); + const auto concat_axis = get_concatenation_axis(); + + NODE_VALIDATION_CHECK(this, + input_shape.is_dynamic() || (0 <= concat_axis && concat_axis < input_rank.get_length()), + "Concatenation axis (", + concat_axis, + ") is out of bounds [", + -input_rank.get_length(), + ", ", + input_rank.get_length() - 1, + "] for ", + "argument ", + i, + ", which has shape ", + input_shape, + "."); + + input_shapes.push_back(input_shape); } + + std::vector output_shapes(1, PartialShape{}); + + shape_infer(this, input_shapes, output_shapes); + set_output_type(0, inputs_et, output_shapes.front()); } shared_ptr op::Concat::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v0_Concat_clone_with_new_inputs); - // TODO(amprocte): Should we check the new_args count here? 
return make_shared(new_args, m_axis); } @@ -139,14 +126,12 @@ bool op::Concat::evaluate_upper(const HostTensorVector& output_values) const { bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const { const auto& inputs = input_values(); - bool has_labeled_input = std::any_of(inputs.begin(), inputs.end(), [](const Output& out) { - const auto& labels = out.get_tensor().get_value_label(); - return !labels.empty() && std::any_of(labels.begin(), labels.end(), [](const size_t& l) { - return l > 0; - }); - }); - if (!has_labeled_input) + if (std::all_of(inputs.cbegin(), inputs.cend(), [](const Output& out) { + const auto& labels = out.get_tensor().get_value_label(); + return has_no_labels(labels); + })) { return false; + } HostTensorVector idx_inputs; idx_inputs.reserve(inputs.size()); @@ -157,7 +142,7 @@ bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const { // sanity check. at this point value propagation was successful NGRAPH_CHECK(shape.is_static()); const auto& num_elements = shape_size(shape.to_shape()); - input_label = TensorLabel(num_elements, 0); + input_label.resize(num_elements, no_label); } const auto& constant = Constant::create(element::u64, input.get_shape(), input_label); idx_inputs.push_back(std::make_shared(constant)); @@ -165,7 +150,6 @@ bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const { const auto& output_tensor = std::make_shared(element::u64, get_output_shape(0)); evaluate({output_tensor}, idx_inputs); - const auto& output_idxs = std::make_shared(output_tensor)->cast_vector(); - output_labels[0] = output_idxs; + output_labels[0] = std::make_shared(output_tensor)->cast_vector(); return true; } diff --git a/src/core/src/validation_util.cpp b/src/core/src/validation_util.cpp index 4812f607fa5..10fdc6a7186 100644 --- a/src/core/src/validation_util.cpp +++ b/src/core/src/validation_util.cpp @@ -1345,7 +1345,6 @@ bool ov::default_label_evaluator(const Node* node, TensorLabelVector& output_lab 
NGRAPH_CHECK(node->outputs().size() == 1); const auto& input_values = node->input_values(); - TensorLabel input_labels; HostTensorVector input_tensors(input_values.size()); for (size_t i = 0; i < input_values.size(); ++i) { @@ -1356,12 +1355,10 @@ bool ov::default_label_evaluator(const Node* node, TensorLabelVector& output_lab else return false; else { - input_labels = input.get_tensor().get_value_label(); - bool no_labels = std::all_of(input_labels.begin(), input_labels.end(), [](const size_t& l) { - return l == 0; - }); - if (input_labels.empty() || no_labels) + const auto& input_labels = input.get_tensor().get_value_label(); + if (has_no_labels(input_labels)) { return false; + } auto labels_constant = op::v0::Constant::create(ov::element::u64, input.get_shape(), input_labels); auto idxs_htp = std::make_shared(labels_constant); @@ -1638,11 +1635,12 @@ shared_ptr ov::get_constant_from_source(const Output& source } bool ngraph::validate_host_tensor_vector(const HostTensorVector& tensor_vector, const size_t& size) { - if (tensor_vector.size() != size) - return false; - return std::all_of(tensor_vector.begin(), tensor_vector.end(), [](const HostTensorPtr& t) { - return t != nullptr; - }); + return (tensor_vector.size() == size) && + std::none_of(tensor_vector.cbegin(), tensor_vector.cend(), ov::cmp::Equal(nullptr)); +} + +bool ov::has_no_labels(const ov::TensorLabel& labels) { + return std::all_of(labels.cbegin(), labels.cend(), cmp::Equal(no_label)); } void ov::generate_transpose_default_order(std::vector& axes_order, const size_t length) { diff --git a/src/core/tests/CMakeLists.txt b/src/core/tests/CMakeLists.txt index 7ecad533d7c..072592ba7af 100644 --- a/src/core/tests/CMakeLists.txt +++ b/src/core/tests/CMakeLists.txt @@ -50,6 +50,7 @@ set(SRC dimension.cpp element_type.cpp eval.cpp + evaluate_bound/concat.cpp evaluate_bound/transpose.cpp extension.cpp file_util.cpp diff --git a/src/core/tests/evaluate_bound/concat.cpp b/src/core/tests/evaluate_bound/concat.cpp 
new file mode 100644 index 00000000000..241b768cff0 --- /dev/null +++ b/src/core/tests/evaluate_bound/concat.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "dimension_tracker.hpp" +#include "gmock/gmock.h" +#include "openvino/core/dimension.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/partial_shape.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/graph_rewrite.hpp" + +using namespace ov::opset9; +using namespace testing; + +using ShapeVector = std::vector; +using LabeledShape = std::tuple; +using LabeledShapeVector = std::vector; +using TestParams = std::tuple; + +class EvaluateLabelTest : public Test { +protected: + ov::element::Type label_dtype{ov::element::u64}; + ov::TensorLabelVector out_labels; + + bool exp_evaluate_status; + ov::TensorVector exp_result, inputs; + std::vector> labels_u64; //!< Storage for tensor labels. +}; + +class ConcatEvaluateLabelTest : public EvaluateLabelTest, public WithParamInterface { +protected: + void SetUp() override { + exp_result = ov::TensorVector{ov::Tensor(ov::element::u64, {0})}; + const auto& labeled_shapes = std::get<1>(GetParam()); + + exp_evaluate_status = + std::any_of(labeled_shapes.cbegin(), labeled_shapes.cend(), [](const LabeledShape& l_shape) { + return std::get<1>(l_shape); + }); + + for (const auto& labeled_shape : labeled_shapes) { + ov::PartialShape shape; + bool add_labels; + std::tie(shape, add_labels) = labeled_shape; + + auto param = params.make(ov::element::u64, shape); + + if (exp_evaluate_status) { + auto min_shape = shape.get_min_shape(); + ov::TensorLabel labels(ov::shape_size(min_shape), ov::no_label); + + if (add_labels) { + std::iota(labels.begin(), labels.end(), 1); + param->get_default_output().get_tensor().set_value_label(labels); + } + + labels_u64.emplace_back(std::vector(labels.cbegin(), labels.cend())); + inputs.emplace_back(label_dtype, min_shape, labels_u64.back().data()); + } + 
} + } + + std::shared_ptr concat; + ov::pass::NodeRegistry params; +}; + +const auto shape1 = ov::PartialShape({3, 2, 1}); +const auto shape2 = ov::PartialShape({3, 4, 1}); + +const auto contactable_shapes_axis_1 = Values( + LabeledShapeVector{std::make_tuple(shape1, false)}, + LabeledShapeVector{std::make_tuple(shape2, false)}, + LabeledShapeVector{std::make_tuple(shape2, false), std::make_tuple(shape1, false)}, + LabeledShapeVector{std::make_tuple(shape1, true), std::make_tuple(shape2, false)}, + LabeledShapeVector{std::make_tuple(shape2, false), std::make_tuple(shape1, true)}, + LabeledShapeVector{std::make_tuple(shape1, true), std::make_tuple(shape2, false), std::make_tuple(shape1, false)}, + LabeledShapeVector{std::make_tuple(shape1, true), std::make_tuple(shape2, false), std::make_tuple(shape2, true)}, + LabeledShapeVector{std::make_tuple(shape1, true), + std::make_tuple(shape2, true), + std::make_tuple(shape2, true), + std::make_tuple(shape1, true)}); + +INSTANTIATE_TEST_SUITE_P(evaluate_bound_contactable_axis_1, + ConcatEvaluateLabelTest, + Combine(Values(1), contactable_shapes_axis_1), + PrintToStringParamName()); + +const auto contactable_shapes = Values( + LabeledShapeVector{std::make_tuple(shape1, false)}, + LabeledShapeVector{std::make_tuple(shape1, false), std::make_tuple(shape1, false)}, + LabeledShapeVector{std::make_tuple(shape2, false), std::make_tuple(shape2, false), std::make_tuple(shape2, true)}, + LabeledShapeVector{std::make_tuple(shape2, true), std::make_tuple(shape2, false), std::make_tuple(shape2, true)}, + LabeledShapeVector{std::make_tuple(shape1, true), std::make_tuple(shape1, true), std::make_tuple(shape1, true)}); + +INSTANTIATE_TEST_SUITE_P(evaluate_bound, + ConcatEvaluateLabelTest, + Combine(testing::Range(-3, 3), contactable_shapes), + PrintToStringParamName()); + +/** \brief Test evaluate label for combination of different shapes and each shape may be labeled. 
*/ +TEST_P(ConcatEvaluateLabelTest, evaluate_label) { + const auto concat = std::make_shared(params.get(), std::get<0>(GetParam())); + out_labels.resize(concat->get_output_size()); + + if (exp_evaluate_status) { + concat->evaluate(exp_result, inputs); + } + + ASSERT_EQ(concat->evaluate_label(out_labels), exp_evaluate_status); + ASSERT_THAT(out_labels.front(), + ElementsAreArray(exp_result.front().data(), exp_result.front().get_size())); +} diff --git a/src/core/tests/type_prop/concat.cpp b/src/core/tests/type_prop/concat.cpp index 541456d7bd7..06f39dbbaa4 100644 --- a/src/core/tests/type_prop/concat.cpp +++ b/src/core/tests/type_prop/concat.cpp @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 // -#include - -#include "gtest/gtest.h" +#include "dimension_tracker.hpp" +#include "gmock/gmock.h" #include "ngraph/ngraph.hpp" +#include "openvino/pass/graph_rewrite.hpp" #include "util/type_prop.hpp" using namespace std; using namespace ngraph; +using namespace testing; TEST(type_prop, concat_deduce) { // Deduce type @@ -17,7 +18,7 @@ TEST(type_prop, concat_deduce) { auto param1 = make_shared(element::f32, Shape{2, 7, 4}); auto param2 = make_shared(element::f32, Shape{2, 2, 4}); auto c = make_shared(NodeVector{param0, param1, param2}, 1); - ASSERT_EQ(c->get_element_type(), element::f32); + EXPECT_EQ(c->get_element_type(), element::f32); ASSERT_EQ(c->get_shape(), (Shape{2, 12, 4})); } @@ -80,7 +81,7 @@ TEST(type_prop, concat_deduce_axis_barely_in_bounds) { auto param1 = make_shared(element::f32, Shape{2, 3, 8}); auto param2 = make_shared(element::f32, Shape{2, 3, 12}); auto c = make_shared(NodeVector{param0, param1, param2}, 2); - ASSERT_EQ(c->get_element_type(), element::f32); + EXPECT_EQ(c->get_element_type(), element::f32); ASSERT_EQ(c->get_shape(), (Shape{2, 3, 24})); } @@ -105,7 +106,7 @@ TEST(type_prop, concat_partial_et_consistent) { auto param2 = make_shared(element::f32, Shape{2, 2, 4}); auto c = make_shared(NodeVector{param0, param1, param2}, 1); - 
ASSERT_EQ(c->get_element_type(), element::f32); + EXPECT_EQ(c->get_element_type(), element::f32); ASSERT_EQ(c->get_shape(), (Shape{2, 12, 4})); } @@ -124,24 +125,6 @@ TEST(type_prop, concat_partial_et_inconsistent) { } } -TEST(type_prop, concat_partial_all_rank_dynamic) { - auto param0 = make_shared(element::f32, PartialShape::dynamic()); - auto param1 = make_shared(element::f32, PartialShape::dynamic()); - auto param2 = make_shared(element::f32, PartialShape::dynamic()); - auto c = make_shared(NodeVector{param0, param1, param2}, 1); - - ASSERT_TRUE(c->get_output_partial_shape(0).rank().is_dynamic()); -} - -TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_dynamic_consistent) { - auto param0 = make_shared(element::f32, PartialShape{2, Dimension::dynamic(), 3}); - auto param1 = make_shared(element::f32, PartialShape::dynamic()); - auto param2 = make_shared(element::f32, PartialShape{2, 3, Dimension::dynamic()}); - auto c = make_shared(NodeVector{param0, param1, param2}, 1); - - ASSERT_TRUE(c->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3})); -} - TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_dynamic_rank_inconsistent) { auto param0 = make_shared(element::f32, PartialShape{2, Dimension::dynamic(), 3}); auto param1 = make_shared(element::f32, PartialShape::dynamic()); @@ -197,15 +180,6 @@ TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_dynamic_dims } } -TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_with_concat_axis_static) { - auto param0 = make_shared(element::f32, PartialShape{2, 2, 3}); - auto param1 = make_shared(element::f32, PartialShape::dynamic()); - auto param2 = make_shared(element::f32, PartialShape{2, 3, Dimension::dynamic()}); - auto c = make_shared(NodeVector{param0, param1, param2}, 1); - - ASSERT_TRUE(c->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3})); -} - TEST(type_prop, 
concat_partial_some_rank_dynamic_others_rank_static_with_concat_axis_static_dims_inconsistent) { auto param0 = make_shared(element::f32, PartialShape{2, 2, 3}); auto param1 = make_shared(element::f32, PartialShape::dynamic()); @@ -234,15 +208,6 @@ TEST(type_prop, concat_partial_all_static_with_concat_axis_static_compatible_res ASSERT_EQ(c->get_shape(), (Shape{2, 9, 3})); } -TEST(type_prop, concat_partial_all_static_with_concat_axis_static_compatible_result_dynamic) { - auto param0 = make_shared(element::f32, PartialShape{2, 2, Dimension::dynamic()}); - auto param1 = make_shared(element::f32, PartialShape{Dimension::dynamic(), 4, Dimension::dynamic()}); - auto param2 = make_shared(element::f32, PartialShape{2, 3, Dimension::dynamic()}); - auto c = make_shared(NodeVector{param0, param1, param2}, 1); - - ASSERT_TRUE(c->get_output_partial_shape(0).same_scheme(PartialShape{2, 9, Dimension::dynamic()})); -} - TEST(type_prop, concat_partial_all_static_with_concat_axis_static_dims_incompatible) { auto param0 = make_shared(element::f32, PartialShape{2, 2, 3}); auto param1 = make_shared(element::f32, PartialShape{Dimension::dynamic(), 4, 3}); @@ -268,7 +233,7 @@ TEST(type_prop, concat_partial_negative_axis_correct) { auto c = make_shared(NodeVector{param0, param1, param2}, -3); - ASSERT_EQ(c->get_element_type(), element::f32); + EXPECT_EQ(c->get_element_type(), element::f32); ASSERT_EQ(c->get_shape(), (Shape{12, 2, 4})); } @@ -288,6 +253,7 @@ TEST(type_prop, concat_partial_negative_axis_incorrect) { } } +/** \brief Test uses evaluate lower/upper and label of concat op. 
*/ TEST(type_prop, concat_dynamic_value_and_label_propagation) { Dimension marked_0 = Dimension(3); ov::DimensionTracker::set_label(marked_0, 10); @@ -308,17 +274,14 @@ TEST(type_prop, concat_dynamic_value_and_label_propagation) { auto target_shape = std::make_shared(OutputVector{shape_0, five, shape_1}, 0); auto bc = make_shared(param, target_shape); - ASSERT_EQ(bc->get_shape(), (Shape{3, 4, 5, 4, 5, 9})); + EXPECT_EQ(bc->get_shape(), (Shape{3, 4, 5, 4, 5, 9})); const auto& output_shape = bc->get_output_partial_shape(0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[1]), 0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[2]), 0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[3]), 0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[4]), 15); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[5]), 0); + const auto labels = get_shape_labels(output_shape); + ASSERT_THAT(labels, ElementsAre(10, 0, 0, 0, 15, 0)); } +/** \brief Test uses evaluate lower/upper and label of concat op. 
*/ TEST(type_prop, concat_dynamic_value_and_label_propagation_1) { Dimension marked_0 = Dimension(3); ov::DimensionTracker::set_label(marked_0, 1000); @@ -343,13 +306,107 @@ TEST(type_prop, concat_dynamic_value_and_label_propagation_1) { auto convert = make_shared(target_shape, element::i64); auto bc = make_shared(param, target_shape); - ASSERT_EQ(bc->get_shape(), (Shape{3, 4, 5, 4, 5, 9})); + EXPECT_EQ(bc->get_shape(), (Shape{3, 4, 5, 4, 5, 9})); const auto& output_shape = bc->get_output_partial_shape(0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 1000); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[1]), 0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[2]), 0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[3]), 0); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[4]), 1500); - ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[5]), 0); + const auto labels = get_shape_labels(output_shape); + ASSERT_THAT(labels, ElementsAre(1000, 0, 0, 0, 1500, 0)); +} + +TEST(type_prop, concat_interval_dimensions) { + auto param0 = make_shared(element::f32, Shape{3, 2, 4}); + auto param1 = make_shared(element::f32, Shape{7, 2, 4}); + auto param2 = make_shared(element::f32, Shape{2, 2, 4}); + + auto c = make_shared(NodeVector{param0, param1, param2}, -3); + + EXPECT_EQ(c->get_element_type(), element::f32); + ASSERT_EQ(c->get_shape(), (Shape{12, 2, 4})); +} + +using PartialShapeVector = std::vector; +using ConcatTestParams = std::tuple>; + +class ConcatTest : public TestWithParam { +protected: + void SetUp() override { + int64_t axis; + PartialShapeVector input_shapes; + ov::pass::NodeRegistry params; + + std::forward_as_tuple(input_shapes, std::tie(axis, exp_shape)) = GetParam(); + + for (const auto& shape : input_shapes) { + params.make(element::f32, shape); + } + + c = make_shared(params.get(), axis); + } + + PartialShape exp_shape; + std::shared_ptr c; +}; + +const auto shapes_with_interval_dim = 
Values(PartialShapeVector{(PartialShape::dynamic()), + {2, Dimension(2, 5), 3, 1}, + {2, 4, 3, Dimension(1, 4)}, + {2, 4, 3, 1}}); + +INSTANTIATE_TEST_SUITE_P(type_prop_interval_dim_mixed_ranks, + ConcatTest, + Combine(shapes_with_interval_dim, + Values(std::make_tuple(1, PartialShape({2, Dimension(10, -1), 3, 1})), // axis 1 + std::make_tuple(-1, PartialShape({2, 4, 3, Dimension(3, -1)})), // axis 2 + std::make_tuple(2, PartialShape({2, 4, Dimension(9, -1), 1})) // axis 3 + )), + PrintToStringParamName()); + +const auto shapes_all_dynamic_ranks = Values(PartialShapeVector{(PartialShape::dynamic()), + (PartialShape::dynamic()), + (PartialShape::dynamic()), + (PartialShape::dynamic())}); + +INSTANTIATE_TEST_SUITE_P(type_prop_dynamic_ranks_against_axis_range, + ConcatTest, + Combine(shapes_all_dynamic_ranks, + Combine(Range(-4, 4), Values(PartialShape::dynamic()))), + PrintToStringParamName()); + +const auto shapes_static_dynamic_ranks = + Values(PartialShapeVector{PartialShape({4, 2, Dimension::dynamic(), 3}), + PartialShape::dynamic(), + PartialShape({4, 2, Dimension::dynamic(), Dimension::dynamic()})}); + +INSTANTIATE_TEST_SUITE_P(type_prop_mixed_ranks_and_dims, + ConcatTest, + Combine(shapes_static_dynamic_ranks, + Values( + // concat all dynamic dims + std::make_tuple(2, PartialShape({4, 2, Dimension::dynamic(), 3})), + // concat dynamic and interval dim + std::make_tuple(1, PartialShape({4, Dimension(4, -1), Dimension::dynamic(), 3})))), + PrintToStringParamName()); + +INSTANTIATE_TEST_SUITE_P(type_prop_1d_shapes, + ConcatTest, + Values( + // concat all dynamic dims + std::make_tuple(PartialShapeVector{{-1}, {-1}, {-1}}, + std::make_tuple(0, PartialShape({-1}))), + // concat dynamic and not matching static dims + std::make_tuple(PartialShapeVector{{3}, PartialShape::dynamic(), {2}}, + std::make_tuple(0, PartialShape({Dimension(5, -1)}))), + // concat all static dim + std::make_tuple(PartialShapeVector{{3}, {3}, {3}}, std::make_tuple(0, PartialShape({9}))), + // 
concat dynamic and interval dim + std::make_tuple(PartialShapeVector{{3}, {Dimension::dynamic()}, {Dimension(3, 4)}}, + std::make_tuple(0, PartialShape({Dimension(6, -1)})))), + PrintToStringParamName()); + +/** \brief Shape propagation no exception. */ +TEST_P(ConcatTest, partial_shape_propagation) { + ASSERT_EQ(c->get_default_output().get_partial_shape(), exp_shape); } diff --git a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp index 4050ddab35b..7b48ab743a3 100644 --- a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp @@ -1,8 +1,6 @@ // Copyright (C) 2018-2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // -#include "shape_inference.hpp" - #include #include #include @@ -17,6 +15,7 @@ #include "batch_to_space_shape_inference.hpp" #include "broadcast_shape_inference.hpp" #include "bucketize_shape_inference.hpp" +#include "concat_shape_inference.hpp" #include "convolution_shape_inference.hpp" #include "ctc_greedy_decoder_seq_len_shape_inference.hpp" #include "ctc_greedy_decoder_shape_inference.hpp" @@ -56,6 +55,7 @@ #include "scatter_elements_update_shape_inference.hpp" #include "scatter_nd_base_shape_inference.hpp" #include "select_shape_inference.hpp" +#include "shape_inference.hpp" #include "shape_nodes.hpp" #include "shuffle_channels_shape_inference.hpp" #include "space_to_batch_shape_inference.hpp" @@ -554,6 +554,8 @@ std::shared_ptr make_shape_inference(const std::shared_ptr>(node); } else if (auto node = ov::as_type_ptr(op)) { return make_shared_entryIOC(node); + } else if (auto node = ov::as_type_ptr(op)) { + return make_shared_entryIO(node); } else { return std::make_shared(op); } diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/concat_shape_inference_test.cpp 
b/src/plugins/intel_cpu/tests/unit/shape_inference_test/concat_shape_inference_test.cpp new file mode 100644 index 00000000000..cf5576a49bf --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/concat_shape_inference_test.cpp @@ -0,0 +1,74 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "concat_shape_inference.hpp" +#include "gtest/gtest.h" +#include "openvino/op/parameter.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "utils/shape_inference/static_shape.hpp" + +using namespace ov; +using namespace ov::intel_cpu; +using namespace testing; + +using ShapeVector = std::vector; +using TestParams = std::tuple; + +class ConcatStaticShapeInferenceTest : public TestWithParam { +protected: + void SetUp() override { + std::tie(concat_axis, input_shapes, exp_shape) = GetParam(); + + for (const auto& in : input_shapes) { + params.make(element::f32, in.get_shape()); + } + concat = std::make_shared(params.get(), concat_axis); + } + + int64_t concat_axis; + StaticShape exp_shape; + ShapeVector input_shapes; + + pass::NodeRegistry params{}; + std::shared_ptr concat; +}; + +/** \brief Concatenate simple 1d shapes. */ +INSTANTIATE_TEST_SUITE_P(concat_1d_shapes, + ConcatStaticShapeInferenceTest, + Values(make_tuple(0, ShapeVector{{0}}, StaticShape({0})), + make_tuple(0, ShapeVector{{3}}, StaticShape({3})), + make_tuple(0, ShapeVector{{1}, {1}}, StaticShape({2})), + make_tuple(0, ShapeVector{{1}, {3}}, StaticShape({4})), + make_tuple(0, ShapeVector{{4}, {1}}, StaticShape({5})), + make_tuple(0, ShapeVector{{4}, {0}}, StaticShape({4})), + make_tuple(-1, ShapeVector{{4}, {0}, {2}}, StaticShape({6})), + make_tuple(-1, ShapeVector{{2}, {7}, {3}}, StaticShape({12}))), + PrintToStringParamName()); + +/** \brief Concatenate complex shapes. 
*/ +INSTANTIATE_TEST_SUITE_P( + concat_complex_shapes, + ConcatStaticShapeInferenceTest, + Values(make_tuple(1, ShapeVector{{0, 0}}, StaticShape({0, 0})), + make_tuple(1, ShapeVector{{3, 1}, {3, 2}}, StaticShape({3, 3})), + make_tuple(0, ShapeVector{{3, 1, 2}, {3, 1, 2}}, StaticShape({6, 1, 2})), + make_tuple(-3, ShapeVector{{3, 1, 2}, {3, 1, 2}}, StaticShape({6, 1, 2})), + make_tuple(2, ShapeVector{{3, 1, 2}, {3, 1, 2}}, StaticShape({3, 1, 4})), + make_tuple(-2, ShapeVector{{3, 1, 2}, {3, 1, 2}}, StaticShape({3, 2, 2})), + make_tuple(-1, ShapeVector{{2, 5, 1, 1}, {2, 5, 1, 2}, {2, 5, 1, 2}}, StaticShape({2, 5, 1, 5})), + make_tuple(2, ShapeVector{{2, 5, 6, 2}, {2, 5, 7, 2}, {2, 5, 1, 2}}, StaticShape({2, 5, 14, 2}))), + PrintToStringParamName()); + +/** \brief Check shape_infer for concat op on static shapes. */ +TEST_P(ConcatStaticShapeInferenceTest, concat_static) { + auto output_shapes = ShapeVector(1); + + shape_infer(concat.get(), input_shapes, output_shapes); + + ASSERT_EQ(output_shapes.front(), exp_shape); +}