[ONNX] Add support for BitShift operator (#4368)

This commit is contained in:
Michał Karzyński 2021-02-17 16:39:27 +01:00 committed by GitHub
parent 45ae389842
commit ec9b5894fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 12 deletions

View File

@ -0,0 +1,67 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "op/bitshift.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
OutputVector bitshift(const Node& node)
{
const Output<ngraph::Node> input_x = node.get_ng_inputs().at(0);
const Output<ngraph::Node> input_y = node.get_ng_inputs().at(1);
std::string direction = node.get_attribute_value<std::string>("direction", "");
CHECK_VALID_NODE(node,
!direction.empty(),
"Required attribute 'direction' is not specified.");
CHECK_VALID_NODE(node,
direction == "LEFT" || direction == "RIGHT",
"Only values 'LEFT' and 'RIGHT' are supported for 'direction' "
"attribute. Given: ",
direction);
auto shift = std::make_shared<default_opset::Power>(
default_opset::Constant::create(input_y.get_element_type(), Shape{1}, {2}),
input_y);
if (direction == "RIGHT")
{
return {std::make_shared<default_opset::Divide>(input_x, shift)};
}
else
{
return {std::make_shared<default_opset::Multiply>(input_x, shift)};
}
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@ -0,0 +1,40 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
OutputVector bitshift(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@ -35,6 +35,7 @@
#include "op/atanh.hpp"
#include "op/average_pool.hpp"
#include "op/batch_norm.hpp"
#include "op/bitshift.hpp"
#include "op/cast.hpp"
#include "op/ceil.hpp"
#include "op/clip.hpp"
@ -324,6 +325,7 @@ namespace ngraph
REGISTER_OPERATOR("Atanh", 1, atanh);
REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
REGISTER_OPERATOR("BitShift", 1, bitshift);
REGISTER_OPERATOR("Cast", 1, cast);
REGISTER_OPERATOR("Ceil", 1, ceil);
REGISTER_OPERATOR("Clip", 1, clip);

View File

@ -39,8 +39,6 @@ xfail_issue_33488 = xfail_test(reason="RuntimeError: nGraph does not support the
"MaxUnpool")
xfail_issue_33512 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"Einsum")
xfail_issue_33515 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"BitShift")
xfail_issue_33535 = xfail_test(reason="nGraph does not support the following ONNX operations:"
"DynamicQuantizeLinear")
xfail_issue_33538 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"

View File

@ -23,7 +23,6 @@ from tests.test_onnx.utils.onnx_backend import OpenVinoTestBackend
from tests import (BACKEND_NAME,
xfail_issue_33488,
xfail_issue_33512,
xfail_issue_33515,
xfail_issue_33535,
xfail_issue_33538,
xfail_issue_33540,
@ -562,15 +561,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_compress_default_axis_cpu",
"OnnxBackendNodeModelTest.test_compress_1_cpu",
"OnnxBackendNodeModelTest.test_compress_0_cpu"),
(xfail_issue_33515,
"OnnxBackendNodeModelTest.test_bitshift_left_uint8_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint64_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint16_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint32_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint8_cpu",
"OnnxBackendNodeModelTest.test_bitshift_left_uint32_cpu",
"OnnxBackendNodeModelTest.test_bitshift_left_uint16_cpu",
"OnnxBackendNodeModelTest.test_bitshift_left_uint64_cpu"),
(xfail_issue_38732,
"OnnxBackendNodeModelTest.test_convinteger_with_padding_cpu",
"OnnxBackendNodeModelTest.test_basic_convinteger_cpu"),