diff --git a/docs/template_plugin/tests/functional/op_reference/loop.cpp b/docs/template_plugin/tests/functional/op_reference/loop.cpp index 9d39725efad..290593b6794 100644 --- a/docs/template_plugin/tests/functional/op_reference/loop.cpp +++ b/docs/template_plugin/tests/functional/op_reference/loop.cpp @@ -9,17 +9,30 @@ #include "base_reference_test.hpp" #include "functional_test_utils/skip_tests_config.hpp" +#include "common_test_utils/common_utils.hpp" + +namespace { +enum LOOP_IN_TYPE { + INVARIANT, + MERGED +}; struct LoopFunctionalBase { virtual std::shared_ptr create_function(const std::vector& loop_inputs, - const std::vector& results) = 0; + const std::vector& results, + const int64_t& trip_count_value = 1, + const std::vector& loop_in_type = {}, + const ov::element::Type& net_type = ov::element::f32) = 0; LoopFunctionalBase() = default; virtual ~LoopFunctionalBase() = default; }; struct LoopDynamicInputs : public LoopFunctionalBase { std::shared_ptr create_function(const std::vector& loop_inputs, - const std::vector& results) override { + const std::vector& results, + const int64_t& trip_count_value, + const std::vector& loop_in_type, + const ov::element::Type& net_type) override { auto X = std::make_shared(ov::element::f32, ov::PartialShape::dynamic()); auto Y = std::make_shared(ov::element::f32, ov::PartialShape::dynamic()); auto M = std::make_shared(ov::element::f32, ov::PartialShape::dynamic()); @@ -113,3 +126,294 @@ INSTANTIATE_TEST_SUITE_P( reference_tests::Tensor(ov::element::f32, ov::Shape{2, 2}, std::vector{5, 108, 375, 686})}, "loop_dynamic_inputs")), ReferenceLoopLayerTest::getTestCaseName); + +struct LoopStaticInputs : public LoopFunctionalBase { + std::shared_ptr create_function(const std::vector& loop_inputs, + const std::vector& results, + const int64_t& trip_count, + const std::vector& loop_in_type, + const ov::element::Type& net_type) override { + ov::ParameterVector loop_params; + for (auto&& input : loop_inputs) { + loop_params.emplace_back(std::make_shared(input.type, input.shape)); + } + + // Set up the cell body, a function from (Xi, Yi) -> (Zo) + // Body parameters + const std::vector body_params_shapes(loop_inputs.size(), ov::PartialShape::dynamic()); + ov::ParameterVector body_params; + for (const auto& pshape : body_params_shapes) { + body_params.emplace_back(std::make_shared(net_type, pshape)); + } + + const auto body_condition_const = std::make_shared(ov::element::boolean, ov::Shape{1}, true); + const auto exec_condition = std::make_shared(ov::element::boolean, ov::Shape{1}, true); + std::shared_ptr trip_count_input; + trip_count_input = std::make_shared(ov::element::i64, ov::Shape{1}, trip_count); + + // Body + std::shared_ptr Zo = body_params[0]; + for (int i = 1; i < body_params.size(); ++i) { + Zo = std::make_shared(body_params[i], Zo); + } + + const auto body = std::make_shared(ov::OutputVector{body_condition_const, Zo}, + body_params); + + const auto loop = std::make_shared(trip_count_input, exec_condition); + loop->set_function(body); + loop->set_special_body_ports(ov::opset8::Loop::SpecialBodyPorts{-1, 0}); + + for (int i = 0; i < body_params.size(); ++i) { + if (loop_in_type[i] == LOOP_IN_TYPE::INVARIANT) { + loop->set_invariant_input(body_params[i], loop_params[i]); + } else if (loop_in_type[i] == LOOP_IN_TYPE::MERGED) { + // todo: support several merged loop_inputs + // now supported only one in this sample + loop->set_merged_input(body_params[i], loop_params[i], Zo); + } + } + + // Output 0 is last Zo + const auto out0 = loop->get_iter_value(body_condition_const, -1); + const auto out1 = loop->get_iter_value(Zo, -1); + // Output 1 is concat of Zos + // start=0, stride=1, part_size=1, end=-1, axis=1 + const auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1); + + const auto result0 = std::make_shared(out0); + const auto result1 = std::make_shared(out1); + const auto result2 = std::make_shared(out2); + const auto function = std::make_shared(ov::ResultVector{result0, result1, result2}, loop_params, "loop"); + return function; + } +}; + +struct LoopStaticParams { + LoopStaticParams( + const std::shared_ptr& functional, + const std::vector& loop_inputs, + const std::vector& expected_results, + const int64_t& trip_count, + const std::vector& loop_in_type, + const ov::element::Type& net_type, + const std::string& test_case_name) + : function(functional), + inputs(loop_inputs), + expected_results(expected_results), + trip_count(trip_count), + loop_in_type(loop_in_type), + net_type(net_type), + test_case_name(test_case_name) {} + + std::shared_ptr function; + std::vector inputs; + std::vector expected_results; + int64_t trip_count; + std::vector loop_in_type; + ov::element::Type net_type; + std::string test_case_name; +}; + +class ReferenceLoopLayerStaticTest : public testing::TestWithParam, public reference_tests::CommonReferenceTest { +public: + void SetUp() override { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + auto params = GetParam(); + function = params.function->create_function(params.inputs, + params.expected_results, + params.trip_count, + params.loop_in_type, + params.net_type); + inputData.reserve(params.inputs.size()); + refOutData.reserve(params.expected_results.size()); + for (auto& input : params.inputs) { + inputData.push_back(input.data); + } + for (auto& output : params.expected_results) { + refOutData.push_back(output.data); + } + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + auto param = obj.param; + std::ostringstream result; + result << "TS="; + for (auto& input : param.inputs) { + result << CommonTestUtils::vec2str(input.shape) << "_"; + } + result << "_tripCount=" << param.trip_count; + result << "_loopInType="; + for (auto& type : param.loop_in_type) { + result << "_" << type; + } + result << "_netType=" << param.net_type; + if (!param.test_case_name.empty()) { + result << "_" << param.test_case_name; + } + return result.str(); + } +}; + +TEST_P(ReferenceLoopLayerStaticTest, CompareWithRefs) { + Exec(); +} + +template +std::vector generateParams() { + using T = typename ov::element_type_traits::value_type; + std::vector params { + LoopStaticParams( + std::make_shared(), + {reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2}), + reference_tests::Tensor( + ET, + {1, 1, 1}, + std::vector{7}), + reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2})}, + {reference_tests::Tensor( + ov::element::Type_t::boolean, + {1}, + std::vector{1}), + reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11}), + reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11})}, + 1, + {LOOP_IN_TYPE::INVARIANT, LOOP_IN_TYPE::INVARIANT, LOOP_IN_TYPE::MERGED}, + ET, + "loop_for_common"), + + LoopStaticParams( + std::make_shared(), + {reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2}), + reference_tests::Tensor( + ET, + {1, 1, 1}, + std::vector{7}), + reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, + 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, + 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2, 3, 0, 1, 6, 7, 4, 5, 2})}, + {reference_tests::Tensor( + ov::element::Type_t::boolean, + {1}, + std::vector{1}), + reference_tests::Tensor( + ET, + {10, 1, 10}, + std::vector{ + 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, + 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, + 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, + 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, + 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47}), + reference_tests::Tensor( + ET, + {10, 5, 10}, + std::vector{ + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 35, 26, 29, 20, 23, 14, 17, 32, 35, 26, + 49, 37, 41, 29, 33, 21, 25, 45, 49, 37, 63, 48, 53, 38, 43, 28, 33, 58, 63, 48, + 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 29, 20, 23, 14, 17, 32, 35, 26, 29, 20, 41, 29, 33, 21, 25, 45, 49, 37, 41, 29, + 53, 38, 43, 28, 33, 58, 63, 48, 53, 38, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, + + 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 23, 14, 17, 32, 35, 26, 29, 20, 23, 14, + 33, 21, 25, 45, 49, 37, 41, 29, 33, 21, 43, 28, 33, 58, 63, 48, 53, 38, 43, 28, + 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, + 17, 32, 35, 26, 29, 20, 23, 14, 17, 32, 25, 45, 49, 37, 41, 29, 33, 21, 25, 45, + 33, 58, 63, 48, 53, 38, 43, 28, 33, 58, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, + + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 35, 26, 29, 20, 23, 14, 17, 32, 35, 26, + 49, 37, 41, 29, 33, 21, 25, 45, 49, 37, 63, 48, 53, 38, 43, 28, 33, 58, 63, 48, + 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 29, 20, 23, 14, 17, 32, 35, 26, 29, 20, 41, 29, 33, 21, 25, 45, 49, 37, 41, 29, + 53, 38, 43, 28, 33, 58, 63, 48, 53, 38, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47, + + 13, 7, 9, 19, 21, 15, 17, 11, 13, 7, 23, 14, 17, 32, 35, 26, 29, 20, 23, 14, + 33, 21, 25, 45, 49, 37, 41, 29, 33, 21, 43, 28, 33, 58, 63, 48, 53, 38, 43, 28, + 53, 35, 41, 71, 77, 59, 65, 47, 53, 35, 9, 19, 21, 15, 17, 11, 13, 7, 9, 19, + 17, 32, 35, 26, 29, 20, 23, 14, 17, 32, 25, 45, 49, 37, 41, 29, 33, 21, 25, 45, + 33, 58, 63, 48, 53, 38, 43, 28, 33, 58, 41, 71, 77, 59, 65, 47, 53, 35, 41, 71, + + 21, 15, 17, 11, 13, 7, 9, 19, 21, 15, 35, 26, 29, 20, 23, 14, 17, 32, 35, 26, + 49, 37, 41, 29, 33, 21, 25, 45, 49, 37, 63, 48, 53, 38, 43, 28, 33, 58, 63, 48, + 77, 59, 65, 47, 53, 35, 41, 71, 77, 59, 17, 11, 13, 7, 9, 19, 21, 15, 17, 11, + 29, 20, 23, 14, 17, 32, 35, 26, 29, 20, 41, 29, 33, 21, 25, 45, 49, 37, 41, 29, + 53, 38, 43, 28, 33, 58, 63, 48, 53, 38, 65, 47, 53, 35, 41, 71, 77, 59, 65, 47})}, + 5, + {LOOP_IN_TYPE::INVARIANT, LOOP_IN_TYPE::INVARIANT, LOOP_IN_TYPE::MERGED}, + ET, + "loop_for_common"), + }; + return params; +} + +std::vector generateCombinedParams() { + const std::vector> generatedParams { + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + generateParams(), + }; + std::vector combinedParams; + + for (const auto& params : generatedParams) { + combinedParams.insert(combinedParams.end(), params.begin(), params.end()); + } + return combinedParams; +} + +INSTANTIATE_TEST_SUITE_P(smoke_Loop_With_Hardcoded_Refs, ReferenceLoopLayerStaticTest, + testing::ValuesIn(generateCombinedParams()), ReferenceLoopLayerStaticTest::getTestCaseName); +}