// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include #include "base_reference_test.hpp" #include "openvino/op/subtract.hpp" using namespace ov; using namespace reference_tests; namespace { struct SubtractParams { template SubtractParams(const PartialShape& iShape1, const PartialShape& iShape2, const element::Type& iType, const std::vector& iValues1, const std::vector& iValues2, const std::vector& oValues) : pshape1(iShape1), pshape2(iShape2), inType(iType), outType(iType), inputData1(CreateTensor(iType, iValues1)), inputData2(CreateTensor(iType, iValues2)), refData(CreateTensor(iType, oValues)) {} PartialShape pshape1; PartialShape pshape2; element::Type inType; element::Type outType; runtime::Tensor inputData1; runtime::Tensor inputData2; runtime::Tensor refData; }; class ReferenceSubtractLayerTest : public testing::TestWithParam, public CommonReferenceTest { public: void SetUp() override { auto params = GetParam(); function = CreateFunction(params.pshape1, params.pshape2, params.inType, params.outType); inputData = {params.inputData1, params.inputData2}; refOutData = {params.refData}; } static std::string getTestCaseName(const testing::TestParamInfo& obj) { auto param = obj.param; std::ostringstream result; result << "iShape1=" << param.pshape1 << "_"; result << "iShape2=" << param.pshape2 << "_"; result << "iType=" << param.inType << "_"; result << "oType=" << param.outType; return result.str(); } private: static std::shared_ptr CreateFunction(const PartialShape& input_shape1, const PartialShape& input_shape2, const element::Type& input_type, const element::Type& expected_output_type) { const auto in1 = std::make_shared(input_type, input_shape1); const auto in2 = std::make_shared(input_type, input_shape2); const auto subtract = std::make_shared(in1, in2); return std::make_shared(NodeVector{subtract}, ParameterVector{in1, in2}); } }; TEST_P(ReferenceSubtractLayerTest, SubtractWithHardcodedRefs) { Exec(); } template std::vector generateParamsForSubtract() { using T = typename element_type_traits::value_type; std::vector params{ SubtractParams(ov::PartialShape{2, 2}, ov::PartialShape{2, 2}, IN_ET, std::vector{2, 4, 8, 16}, std::vector{1, 2, 4, 8}, std::vector{1, 2, 4, 8}), SubtractParams(ov::PartialShape{3, 2, 1}, ov::PartialShape{1, 6}, IN_ET, std::vector{12, 24, 36, 48, 60, 72}, std::vector{1, 2, 3, 4, 6, 1}, std::vector{11, 10, 9, 8, 6, 11, 23, 22, 21, 20, 18, 23, 35, 34, 33, 32, 30, 35, 47, 46, 45, 44, 42, 47, 59, 58, 57, 56, 54, 59, 71, 70, 69, 68, 66, 71}), SubtractParams(ov::PartialShape{1}, ov::PartialShape{1}, IN_ET, std::vector{8}, std::vector{2}, std::vector{6}) }; return params; } template std::vector generateParamsForSubtractFloat() { using T = typename element_type_traits::value_type; std::vector params{ SubtractParams(ov::PartialShape{1}, ov::PartialShape{1}, IN_ET, std::vector{3.1}, std::vector{8}, std::vector{-4.9}) }; return params; } std::vector generateCombinedParamsForSubtract() { const std::vector> allTypeParams{ generateParamsForSubtract(), generateParamsForSubtract(), generateParamsForSubtract(), generateParamsForSubtract(), generateParamsForSubtract(), generateParamsForSubtract(), generateParamsForSubtract() }; std::vector combinedParams; for (const auto& params : allTypeParams) { combinedParams.insert(combinedParams.end(), params.begin(), params.end()); } return combinedParams; } std::vector generateCombinedParamsForSubtractFloat() { const std::vector> allTypeParams{ generateParamsForSubtractFloat(), generateParamsForSubtractFloat() }; std::vector combinedParams; for (const auto& params : allTypeParams) { combinedParams.insert(combinedParams.end(), params.begin(), params.end()); } return combinedParams; } INSTANTIATE_TEST_SUITE_P( smoke_Subtract_With_Hardcoded_Refs, ReferenceSubtractLayerTest, ::testing::ValuesIn(generateCombinedParamsForSubtract()), ReferenceSubtractLayerTest::getTestCaseName); INSTANTIATE_TEST_SUITE_P( smoke_Subtract_Float_With_Hardcoded_Refs, ReferenceSubtractLayerTest, ::testing::ValuesIn(generateCombinedParamsForSubtractFloat()), ReferenceSubtractLayerTest::getTestCaseName); } // namespace