Add support for com.microsoft.BiasGelu operator (#7480)
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "op/com.microsoft/bias_gelu.hpp"
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
OutputVector bias_gelu(const Node& node) {
|
||||
auto nodes = node.get_ng_inputs();
|
||||
NGRAPH_CHECK(nodes.size() == 2, "BiasGelu takes 2 inputs. Provided " + std::to_string(nodes.size()));
|
||||
return {std::make_shared<default_opset::Gelu>(std::make_shared<default_opset::Add>(nodes.at(0), nodes.at(1)))};
|
||||
}
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "onnx_import/core/node.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
OutputVector bias_gelu(const Node& node);
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "op/cast_like.hpp"
|
||||
#include "op/ceil.hpp"
|
||||
#include "op/clip.hpp"
|
||||
#include "op/com.microsoft/bias_gelu.hpp"
|
||||
#include "op/compress.hpp"
|
||||
#include "op/concat.hpp"
|
||||
#include "op/constant.hpp"
|
||||
@@ -261,6 +262,8 @@ bool OperatorsBridge::_is_operator_registered(const std::string& name,
|
||||
}
|
||||
}
|
||||
|
||||
static const char* const MICROSOFT_DOMAIN = "com.microsoft";
|
||||
|
||||
#define REGISTER_OPERATOR(name_, ver_, fn_) \
|
||||
m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1))
|
||||
|
||||
@@ -472,6 +475,8 @@ OperatorsBridge::OperatorsBridge() {
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBoxClustered", 1, prior_box_clustered);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Swish", 1, swish);
|
||||
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "BiasGelu", 1, bias_gelu);
|
||||
}
|
||||
|
||||
#undef REGISTER_OPERATOR
|
||||
|
||||
61
ngraph/test/models/onnx/bias_gelu.prototxt
Normal file
61
ngraph/test/models/onnx/bias_gelu.prototxt
Normal file
@@ -0,0 +1,61 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "X"
|
||||
input: "Y"
|
||||
output: "out"
|
||||
name: "Gelu_AddBias_1"
|
||||
op_type: "BiasGelu"
|
||||
domain: "com.microsoft"
|
||||
}
|
||||
name: "test_graph"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -4150,3 +4150,23 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_random_uniform_like) {
|
||||
test_case.add_input<ngraph::float16>(Shape{2, 2}, {41, 42, 43, 44});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_bias_gelu) {
|
||||
const auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/bias_gelu.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({0.5488135,
|
||||
0.71518934,
|
||||
0.60276335,
|
||||
0.5448832,
|
||||
0.4236548,
|
||||
0.6458941,
|
||||
0.4375872,
|
||||
0.891773,
|
||||
0.96366274,
|
||||
0.3834415});
|
||||
test_case.add_input<float>({0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606});
|
||||
test_case.add_expected_output<float>(
|
||||
{1.2198428, 1.1112978, 1.0293297, 1.366493, 0.3411342, 1.329408, 0.8051748, 1.354462, 1.8336612, 0.3068893});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user