// Copyright (C) 2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include #include "openvino/op/gather_tree.hpp" #include "base_reference_test.hpp" using namespace reference_tests; using namespace ov; namespace { struct GatherTreeParams { template GatherTreeParams(const ov::Shape inShape, std::vector stepIds, const std::vector parentIds, const std::vector maxSeqLen, const std::vector endToken, std::vector output) : stepIdsTensor(inShape, element::from(), stepIds), parentIdsTensor(inShape, element::from(), parentIds), maxSeqLenTensor(ov::Shape{inShape[1]}, element::from(), maxSeqLen), endTokenTensor(ov::Shape{}, element::from(), endToken), expectedTensor(inShape, element::from(), output) {} Tensor stepIdsTensor; Tensor parentIdsTensor; Tensor maxSeqLenTensor; Tensor endTokenTensor; Tensor expectedTensor; }; class ReferenceGatherTreeTest : public testing::TestWithParam, public CommonReferenceTest { public: void SetUp() override { auto params = GetParam(); function = CreateFunction(params); inputData = {params.stepIdsTensor.data, params.parentIdsTensor.data, params.maxSeqLenTensor.data, params.endTokenTensor.data}; refOutData = {params.expectedTensor.data}; } static std::string getTestCaseName(const testing::TestParamInfo& obj) { auto param = obj.param; std::ostringstream result; result << "iType=" << param.stepIdsTensor.type << "_"; result << "iShape=" << param.stepIdsTensor.shape; return result.str(); } private: static std::shared_ptr CreateFunction(const GatherTreeParams& params) { const auto stepIds = std::make_shared(params.stepIdsTensor.type, params.stepIdsTensor.shape); const auto parentIds = std::make_shared(params.parentIdsTensor.type, params.parentIdsTensor.shape); const auto maxSeqLen = std::make_shared(params.maxSeqLenTensor.type, params.maxSeqLenTensor.shape); const auto endToken = std::make_shared(params.endTokenTensor.type, params.endTokenTensor.shape); const auto gatherTree = std::make_shared(stepIds, parentIds, maxSeqLen, endToken); return std::make_shared(NodeVector {gatherTree}, ParameterVector {stepIds, parentIds, maxSeqLen, endToken}); } }; TEST_P(ReferenceGatherTreeTest, CompareWithRefs) { Exec(); } template std::vector generateGatherTreeParams() { using T = typename element_type_traits::value_type; std::vector gatherTreeParams { GatherTreeParams(Shape{4, 1, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1}, std::vector{0, 0, 0, 0, 1, 1, 2, 1, 2, -1, -1, -1}, std::vector{3}, std::vector{10}, std::vector{2, 2, 2, 6, 5, 6, 7, 8, 9, 10, 10, 10}), GatherTreeParams(Shape{2, 2, 2}, std::vector{1, 2, 3, 4, 5, 6, 7, 8}, std::vector{0, 0, 0, 0, 0, 0, 0, 0}, std::vector{2, 4}, std::vector{0}, std::vector{1, 1, 3, 3, 5, 6, 7, 8}) }; return gatherTreeParams; } std::vector generateGatherTreeCombinedParams() { const std::vector> gatherTreeTypeParams { generateGatherTreeParams(), generateGatherTreeParams()}; std::vector combinedParams; for (const auto& params : gatherTreeTypeParams) { combinedParams.insert(combinedParams.end(), params.begin(), params.end()); } return combinedParams; } INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_With_Hardcoded_Refs, ReferenceGatherTreeTest, testing::ValuesIn(generateGatherTreeCombinedParams()), ReferenceGatherTreeTest::getTestCaseName); } // namespace