// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include #include "openvino/op/greater_eq.hpp" #include "comparison.hpp" using namespace ov; using ComparisonTypes = ngraph::helpers::ComparisonTypes; namespace reference_tests { namespace ComparisonOpsRefTestDefinitions { namespace { template std::vector generateComparisonParams(const element::Type& type) { using T = typename element_type_traits::value_type; std::vector compParams { // 1D // 2D // 3D // 4D Builder {} .compType(ComparisonTypes::GREATER_EQUAL) .input1({{2, 2}, type, std::vector {0, 12, 23, 0}}) .input2({{2, 2}, type, std::vector {0, 12, 23, 0}}) .expected({{2, 2}, element::boolean, std::vector {1, 1, 1, 1}}), Builder {} .compType(ComparisonTypes::GREATER_EQUAL) .input1({{2, 3}, type, std::vector {0, 6, 45, 1, 21, 21}}) .input2({{2, 3}, type, std::vector {1, 18, 23, 1, 19, 21}}) .expected({{2, 3}, element::boolean, std::vector {0, 0, 1, 1, 1, 1}}), Builder {} .compType(ComparisonTypes::GREATER_EQUAL) .input1({{1}, type, std::vector {53}}) .input2({{1}, type, std::vector {53}}) .expected({{1}, element::boolean, std::vector {1}}), Builder {} .compType(ComparisonTypes::GREATER_EQUAL) .input1({{2, 4}, type, std::vector {0, 12, 23, 0, 1, 5, 12, 8}}) .input2({{2, 4}, type, std::vector {0, 12, 23, 0, 10, 5, 11, 8}}) .expected({{2, 4}, element::boolean, std::vector {1, 1, 1, 1, 0, 1, 1, 1}}), Builder {} .compType(ComparisonTypes::GREATER_EQUAL) .input1({{3, 1, 2}, type, std::vector {2, 1, 4, 1, 3, 1}}) .input2({{1, 2, 1}, type, std::vector {1, 1}}) .expected({{3, 2, 2}, element::boolean, std::vector {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}), Builder {} .compType(ComparisonTypes::GREATER_EQUAL) .input1({{2, 1, 2, 1}, type, std::vector {2, 1, 4, 1}}) .input2({{1, 2, 1}, type, std::vector {1, 1}}) .expected({{2, 1, 2, 1}, element::boolean, std::vector {1, 1, 1, 1}})}; return compParams; } std::vector generateComparisonCombinedParams() { const std::vector> compTypeParams { generateComparisonParams(element::f32), generateComparisonParams(element::f16), generateComparisonParams(element::i64), generateComparisonParams(element::i32), generateComparisonParams(element::u64), generateComparisonParams(element::u32), generateComparisonParams(element::boolean)}; std::vector combinedParams; for (const auto& params : compTypeParams) { combinedParams.insert(combinedParams.end(), params.begin(), params.end()); } return combinedParams; } } // namespace INSTANTIATE_TEST_SUITE_P(smoke_Comparison_With_Hardcoded_Refs, ReferenceComparisonLayerTest, ::testing::ValuesIn(generateComparisonCombinedParams()), ReferenceComparisonLayerTest::getTestCaseName); } // namespace ComparisonOpsRefTestDefinitions } // namespace reference_tests