[ONNX] Extend ONNX Frontend with Function GreaterOrEqual

* ONNX greater_or_equal enabled
This commit is contained in:
Siddhant Chauhan 2023-11-15 14:04:31 +05:30 committed by GitHub
parent 5365bfe0a7
commit aa1fcbbad1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 208 additions and 0 deletions

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op/greater_or_equal.hpp"
#include <memory>
#include <vector>
#include "default_opset.hpp"
OPENVINO_SUPPRESS_DEPRECATED_START
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector greater_or_equal(const Node& node) {
const auto A = node.get_ng_inputs().at(0);
const auto B = node.get_ng_inputs().at(1);
NGRAPH_CHECK(A.get_element_type() != ov::element::bf16 && B.get_element_type() != ov::element::bf16,
"The input data bfloat16 isn't supported in opset 12");
const auto C = std::make_shared<default_opset::GreaterEqual>(A, B);
return {C};
}
} // namespace set_1
namespace set_16 {
OutputVector greater_or_equal(const Node& node) {
const auto A = node.get_ng_inputs().at(0);
const auto B = node.get_ng_inputs().at(1);
const auto C = std::make_shared<default_opset::GreaterEqual>(A, B);
return {C};
}
} // namespace set_16
} // namespace op
} // namespace onnx_import
} // namespace ngraph
OPENVINO_SUPPRESS_DEPRECATED_END

View File

@ -0,0 +1,28 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/deprecated.hpp"
OPENVINO_SUPPRESS_DEPRECATED_START
#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector greater_or_equal(const Node& node);
} // namespace set_1
namespace set_16 {
OutputVector greater_or_equal(const Node& node);
} // namespace set_16
} // namespace op
} // namespace onnx_import
} // namespace ngraph
OPENVINO_SUPPRESS_DEPRECATED_END

View File

@ -74,6 +74,7 @@
#include "op/global_average_pool.hpp" #include "op/global_average_pool.hpp"
#include "op/global_max_pool.hpp" #include "op/global_max_pool.hpp"
#include "op/greater.hpp" #include "op/greater.hpp"
#include "op/greater_or_equal.hpp"
#include "op/grid_sample.hpp" #include "op/grid_sample.hpp"
#include "op/group_normalization.hpp" #include "op/group_normalization.hpp"
#include "op/gru.hpp" #include "op/gru.hpp"
@ -395,6 +396,8 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool); REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool); REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater); REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("Greater_Or_Equal", 1, greater_or_equal);
REGISTER_OPERATOR("Greater_Or_Equal", 16, greater_or_equal);
REGISTER_OPERATOR("GridSample", 1, grid_sample); REGISTER_OPERATOR("GridSample", 1, grid_sample);
REGISTER_OPERATOR("GroupNormalization", 1, group_normalization); REGISTER_OPERATOR("GroupNormalization", 1, group_normalization);
REGISTER_OPERATOR("GRU", 1, gru); REGISTER_OPERATOR("GRU", 1, gru);

View File

@ -0,0 +1,53 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "C"
op_type: "Greater_Or_Equal"
}
name: "test_greater_or_equal_float"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "C"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 16
}

View File

@ -0,0 +1,53 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "C"
op_type: "Greater_Or_Equal"
}
name: "test_greater_or_equal_int"
input {
name: "A"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "C"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 16
}

View File

@ -6976,3 +6976,31 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_mm_nms_rotated) {
test_case.run(); test_case.run();
} }
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_int) {
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/greater_or_equal_int.onnx"));
auto test_case = ov::test::TestCase(function, s_device);
test_case.add_input<int64_t>(Shape{2}, {10, 20});
test_case.add_input<int64_t>(Shape{2}, {15, 15});
test_case.add_expected_output<bool>(Shape{2}, {false, true});
test_case.run();
}
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_float) {
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/greater_or_equal_float.onnx"));
auto test_case = ov::test::TestCase(function, s_device);
test_case.add_input<float>(Shape{2}, {12.03513f, 22.03513f});
test_case.add_input<float>(Shape{2}, {5.84916f, 22.03513f});
test_case.add_expected_output<bool>(Shape{2}, {true, true});
test_case.run();
}