// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include #include "openvino/op/equal.hpp" #include "comparison.hpp" using namespace ov; using ComparisonTypes = ngraph::helpers::ComparisonTypes; namespace reference_tests { namespace ComparisonOpsRefTestDefinitions { namespace { template std::vector generateComparisonParams(const element::Type& type) { using T = typename element_type_traits::value_type; std::vector compParams { // 1D // 2D // 3D // 4D Builder {} .compType(ComparisonTypes::EQUAL) .input1({{2, 2}, type, std::vector {0, 12, 23, 0}}) .input2({{2, 2}, type, std::vector {0, 12, 23, 0}}) .expected({{2, 2}, element::boolean, std::vector {1, 1, 1, 1}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{2, 3}, type, std::vector {0, 6, 45, 1, 21, 21}}) .input2({{2, 3}, type, std::vector {1, 18, 23, 1, 19, 21}}) .expected({{2, 3}, element::boolean, std::vector {0, 0, 0, 1, 0, 1}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{1}, type, std::vector {53}}) .input2({{1}, type, std::vector {53}}) .expected({{1}, element::boolean, std::vector {1}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{2, 4}, type, std::vector {0, 12, 23, 0, 1, 5, 11, 8}}) .input2({{2, 4}, type, std::vector {0, 12, 23, 0, 10, 5, 11, 8}}) .expected({{2, 4}, element::boolean, std::vector {1, 1, 1, 1, 0, 1, 1, 1}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{3, 1, 2}, type, std::vector {2, 1, 4, 1, 3, 1}}) .input2({{1, 2, 1}, type, std::vector {1, 1}}) .expected({{3, 2, 2}, element::boolean, std::vector {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{2, 1, 2, 1}, type, std::vector {2, 1, 4, 1}}) .input2({{1, 2, 1}, type, std::vector {1, 1}}) .expected({{2, 1, 2, 1}, element::boolean, std::vector {0, 1, 0, 1}})}; return compParams; } std::vector generateComparisonCombinedParams() { const std::vector> compTypeParams { generateComparisonParams(element::f32), generateComparisonParams(element::f16), generateComparisonParams(element::i32), generateComparisonParams(element::u32), generateComparisonParams(element::boolean)}; std::vector combinedParams; for (const auto& params : compTypeParams) { combinedParams.insert(combinedParams.end(), params.begin(), params.end()); } return combinedParams; } INSTANTIATE_TEST_SUITE_P(smoke_Comparison_With_Hardcoded_Refs, ReferenceComparisonLayerTest, ::testing::ValuesIn(generateComparisonCombinedParams()), ReferenceComparisonLayerTest::getTestCaseName); template std::vector generateNumericParams(const element::Type& type) { using T = typename element_type_traits::value_type; std::vector compParams { Builder {} .compType(ComparisonTypes::EQUAL) .input1({{4}, type, std::vector {-2.5f, 25.5f, 2.25f, NAN}}) .input2({{4}, type, std::vector {10.0f, 5.0f, 2.25f, 10.0f}}) .expected({{4}, element::boolean, std::vector {0, 0, 1, 0, }}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{2, 3}, type, std::vector {0.0f, NAN, NAN, 1.0f, 21.0f, -INFINITY}}) .input2({{2, 3}, type, std::vector {1.0f, NAN, 23.0f, 1.0f, 19.0f, 21.0f}}) .expected({{2, 3}, element::boolean, std::vector {0, 0, 0, 1, 0, 0}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{1}, type, std::vector {INFINITY}}) .input2({{1}, type, std::vector {INFINITY}}) .expected({{1}, element::boolean, std::vector {1}}), Builder {} .compType(ComparisonTypes::EQUAL) .input1({{5}, type, std::vector {-2.5f, 25.5f, 2.25f, INFINITY, 6.0f}}) .input2({{5}, type, std::vector {10.0f, 5.0f, 2.25f, 10.0f, -INFINITY}}) .expected({{5}, element::boolean, std::vector {0, 0, 1, 0, 0}})}; return compParams; } std::vector generateNumericCombinedParams() { const std::vector> compTypeParams { generateNumericParams(element::f16), generateNumericParams(element::f32)}; std::vector combinedParams; for (const auto& params : compTypeParams) { combinedParams.insert(combinedParams.end(), params.begin(), params.end()); } return combinedParams; } INSTANTIATE_TEST_SUITE_P(smoke_Numeric_With_Hardcoded_Refs, ReferenceComparisonLayerTest, ::testing::ValuesIn(generateNumericCombinedParams()), ReferenceComparisonLayerTest::getTestCaseName); } // namespace } // namespace ComparisonOpsRefTestDefinitions } // namespace reference_tests