Add support for com.microsoft.BiasGelu operator (#7480)

This commit is contained in:
Mateusz Tabaka
2021-09-14 13:02:31 +02:00
committed by GitHub
parent ba34a1989c
commit 2c4009e3d8
5 changed files with 124 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op/com.microsoft/bias_gelu.hpp"
#include "default_opset.hpp"
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
/// \brief Imports the com.microsoft.BiasGelu contrib operator as Gelu(A + B).
OutputVector bias_gelu(const Node& node) {
    const auto inputs = node.get_ng_inputs();
    NGRAPH_CHECK(inputs.size() == 2, "BiasGelu takes 2 inputs. Provided " + std::to_string(inputs.size()));
    // Add the bias (second input) to the data (first input), then apply Gelu.
    const auto biased = std::make_shared<default_opset::Add>(inputs.at(0), inputs.at(1));
    return {std::make_shared<default_opset::Gelu>(biased)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@@ -0,0 +1,17 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "onnx_import/core/node.hpp"
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
/// \brief Converts the com.microsoft.BiasGelu ONNX contrib operator
///        (element-wise Gelu applied to the sum of the node's two inputs).
OutputVector bias_gelu(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@@ -29,6 +29,7 @@
#include "op/cast_like.hpp"
#include "op/ceil.hpp"
#include "op/clip.hpp"
#include "op/com.microsoft/bias_gelu.hpp"
#include "op/compress.hpp"
#include "op/concat.hpp"
#include "op/constant.hpp"
@@ -261,6 +262,8 @@ bool OperatorsBridge::_is_operator_registered(const std::string& name,
}
}
// Domain string for operators registered from the com.microsoft ONNX contrib domain.
static const char* const MICROSOFT_DOMAIN = "com.microsoft";
// Registers fn_ (from op::set_<ver_>) as the handler for operator name_ in the
// default ("") ONNX domain at opset version ver_.
#define REGISTER_OPERATOR(name_, ver_, fn_) \
m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1))
@@ -472,6 +475,8 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBoxClustered", 1, prior_box_clustered);
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Swish", 1, swish);
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "BiasGelu", 1, bias_gelu);
}
#undef REGISTER_OPERATOR

View File

@@ -0,0 +1,61 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "Y"
output: "out"
name: "Gelu_AddBias_1"
op_type: "BiasGelu"
domain: "com.microsoft"
}
name: "test_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "out"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -4150,3 +4150,23 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_random_uniform_like) {
test_case.add_input<ngraph::float16>(Shape{2, 2}, {41, 42, 43, 44});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_bias_gelu) {
    // BiasGelu computes Gelu(X + Y); the model broadcasts Y (shape {5}) over X (shape {2, 5}).
    const auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/bias_gelu.onnx"));
    auto test_case = test::TestCase<TestEngine>(function);
    const std::vector<float> data{0.5488135,
                                  0.71518934,
                                  0.60276335,
                                  0.5448832,
                                  0.4236548,
                                  0.6458941,
                                  0.4375872,
                                  0.891773,
                                  0.96366274,
                                  0.3834415};
    const std::vector<float> bias{0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606};
    const std::vector<float> expected{1.2198428,
                                      1.1112978,
                                      1.0293297,
                                      1.366493,
                                      0.3411342,
                                      1.329408,
                                      0.8051748,
                                      1.354462,
                                      1.8336612,
                                      0.3068893};
    test_case.add_input<float>(data);
    test_case.add_input<float>(bias);
    test_case.add_expected_output<float>(expected);
    test_case.run();
}