[ONNX] Extend ONNX Frontend with Function GreaterOrEqual
* ONNX greater_or_equal enabled
This commit is contained in:
parent
5365bfe0a7
commit
aa1fcbbad1
43
src/frontends/onnx/frontend/src/op/greater_or_equal.cpp
Normal file
43
src/frontends/onnx/frontend/src/op/greater_or_equal.cpp
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "op/greater_or_equal.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "default_opset.hpp"
|
||||||
|
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
namespace ngraph {
|
||||||
|
namespace onnx_import {
|
||||||
|
namespace op {
|
||||||
|
namespace set_1 {
|
||||||
|
OutputVector greater_or_equal(const Node& node) {
|
||||||
|
const auto A = node.get_ng_inputs().at(0);
|
||||||
|
const auto B = node.get_ng_inputs().at(1);
|
||||||
|
|
||||||
|
NGRAPH_CHECK(A.get_element_type() != ov::element::bf16 && B.get_element_type() != ov::element::bf16,
|
||||||
|
"The input data bfloat16 isn't supported in opset 12");
|
||||||
|
|
||||||
|
const auto C = std::make_shared<default_opset::GreaterEqual>(A, B);
|
||||||
|
|
||||||
|
return {C};
|
||||||
|
}
|
||||||
|
} // namespace set_1
|
||||||
|
|
||||||
|
namespace set_16 {
|
||||||
|
OutputVector greater_or_equal(const Node& node) {
|
||||||
|
const auto A = node.get_ng_inputs().at(0);
|
||||||
|
const auto B = node.get_ng_inputs().at(1);
|
||||||
|
|
||||||
|
const auto C = std::make_shared<default_opset::GreaterEqual>(A, B);
|
||||||
|
|
||||||
|
return {C};
|
||||||
|
}
|
||||||
|
} // namespace set_16
|
||||||
|
} // namespace op
|
||||||
|
} // namespace onnx_import
|
||||||
|
} // namespace ngraph
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
28
src/frontends/onnx/frontend/src/op/greater_or_equal.hpp
Normal file
28
src/frontends/onnx/frontend/src/op/greater_or_equal.hpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/deprecated.hpp"
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
|
||||||
|
#include "ngraph/node.hpp"
|
||||||
|
#include "onnx_import/core/node.hpp"
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace onnx_import {
|
||||||
|
namespace op {
|
||||||
|
namespace set_1 {
|
||||||
|
OutputVector greater_or_equal(const Node& node);
|
||||||
|
|
||||||
|
} // namespace set_1
|
||||||
|
|
||||||
|
namespace set_16 {
|
||||||
|
OutputVector greater_or_equal(const Node& node);
|
||||||
|
|
||||||
|
} // namespace set_16
|
||||||
|
} // namespace op
|
||||||
|
} // namespace onnx_import
|
||||||
|
} // namespace ngraph
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
@ -74,6 +74,7 @@
|
|||||||
#include "op/global_average_pool.hpp"
|
#include "op/global_average_pool.hpp"
|
||||||
#include "op/global_max_pool.hpp"
|
#include "op/global_max_pool.hpp"
|
||||||
#include "op/greater.hpp"
|
#include "op/greater.hpp"
|
||||||
|
#include "op/greater_or_equal.hpp"
|
||||||
#include "op/grid_sample.hpp"
|
#include "op/grid_sample.hpp"
|
||||||
#include "op/group_normalization.hpp"
|
#include "op/group_normalization.hpp"
|
||||||
#include "op/gru.hpp"
|
#include "op/gru.hpp"
|
||||||
@ -395,6 +396,8 @@ OperatorsBridge::OperatorsBridge() {
|
|||||||
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
|
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
|
||||||
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
|
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
|
||||||
REGISTER_OPERATOR("Greater", 1, greater);
|
REGISTER_OPERATOR("Greater", 1, greater);
|
||||||
|
REGISTER_OPERATOR("Greater_Or_Equal", 1, greater_or_equal);
|
||||||
|
REGISTER_OPERATOR("Greater_Or_Equal", 16, greater_or_equal);
|
||||||
REGISTER_OPERATOR("GridSample", 1, grid_sample);
|
REGISTER_OPERATOR("GridSample", 1, grid_sample);
|
||||||
REGISTER_OPERATOR("GroupNormalization", 1, group_normalization);
|
REGISTER_OPERATOR("GroupNormalization", 1, group_normalization);
|
||||||
REGISTER_OPERATOR("GRU", 1, gru);
|
REGISTER_OPERATOR("GRU", 1, gru);
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "A"
|
||||||
|
input: "B"
|
||||||
|
output: "C"
|
||||||
|
op_type: "Greater_Or_Equal"
|
||||||
|
}
|
||||||
|
name: "test_greater_or_equal_float"
|
||||||
|
input {
|
||||||
|
name: "A"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "C"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 2
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 16
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "A"
|
||||||
|
input: "B"
|
||||||
|
output: "C"
|
||||||
|
op_type: "Greater_Or_Equal"
|
||||||
|
}
|
||||||
|
name: "test_greater_or_equal_int"
|
||||||
|
input {
|
||||||
|
name: "A"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "C"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 2
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 16
|
||||||
|
}
|
@ -6976,3 +6976,31 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_mm_nms_rotated) {
|
|||||||
|
|
||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_int) {
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/greater_or_equal_int.onnx"));
|
||||||
|
|
||||||
|
auto test_case = ov::test::TestCase(function, s_device);
|
||||||
|
|
||||||
|
test_case.add_input<int64_t>(Shape{2}, {10, 20});
|
||||||
|
test_case.add_input<int64_t>(Shape{2}, {15, 15});
|
||||||
|
test_case.add_expected_output<bool>(Shape{2}, {false, true});
|
||||||
|
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_float) {
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/greater_or_equal_float.onnx"));
|
||||||
|
|
||||||
|
auto test_case = ov::test::TestCase(function, s_device);
|
||||||
|
|
||||||
|
test_case.add_input<float>(Shape{2}, {12.03513f, 22.03513f});
|
||||||
|
test_case.add_input<float>(Shape{2}, {5.84916f, 22.03513f});
|
||||||
|
test_case.add_expected_output<bool>(Shape{2}, {true, true});
|
||||||
|
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user