Add atanh to onnx importer opset4 (#1425)

This commit is contained in:
Jan Iwaszkiewicz 2020-07-28 10:58:53 +02:00 committed by GitHub
parent 0b1ef99fd7
commit 12457ca85b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 6 additions and 69 deletions

View File

@@ -53,7 +53,6 @@ add_library(onnx_importer SHARED
op/asin.hpp
op/asinh.hpp
op/atan.hpp
op/atanh.cpp
op/atanh.hpp
op/average_pool.cpp
op/average_pool.hpp

View File

@@ -1,64 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "atanh.hpp"
#include "default_opset.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                /// \brief Builds the ONNX Atanh operator from elementary nGraph ops.
                ///
                /// Uses the logarithmic identity
                ///     atanh(x) = 0.5 * ln((1 + x) / (1 - x))
                /// since set_1 predates a dedicated Atanh op.
                ///
                /// \param node  ONNX node whose first input is the tensor x.
                /// \return A single-element NodeVector holding the result node.
                NodeVector atanh(const Node& node)
                {
                    const std::shared_ptr<ngraph::Node> x{node.get_ng_inputs().at(0)};

                    // Scalar constants share x's element type so no casts are needed.
                    const auto one =
                        default_opset::Constant::create(x->get_element_type(), {}, {1.f});
                    const auto half =
                        default_opset::Constant::create(x->get_element_type(), {}, {0.5f});

                    const auto numerator = std::make_shared<default_opset::Add>(one, x);
                    const auto denominator = std::make_shared<default_opset::Subtract>(one, x);
                    const auto quotient =
                        std::make_shared<default_opset::Divide>(numerator, denominator);
                    const auto natural_log = std::make_shared<default_opset::Log>(quotient);

                    return {std::make_shared<default_opset::Multiply>(half, natural_log)};
                }
            } // namespace set_1
        }     // namespace op
    }         // namespace onnx_import
} // namespace ngraph

View File

@@ -17,6 +17,7 @@
#pragma once
#include "core/node.hpp"
#include "default_opset.hpp"
#include "ngraph/node.hpp"
namespace ngraph
@@ -27,10 +28,12 @@ namespace ngraph
{
namespace set_1
{
NodeVector atanh(const Node& node);
/// \brief Maps the ONNX Atanh operator directly onto the default-opset
///        Atanh node (replaces the former hand-rolled log-based expansion).
///
/// \param node  ONNX node whose first input is the tensor argument.
/// \return A single-element NodeVector holding the Atanh result node.
inline NodeVector atanh(const Node& node)
{
    const auto x = node.get_ng_inputs().at(0);
    return {std::make_shared<default_opset::Atanh>(x)};
}
} // namespace set_1
} // namespace op
} // namespace op
} // namespace onnx_import

View File

@@ -158,4 +158,3 @@ NGRAPH_OP(Atanh, ngraph::op::v3)
// X-macro registration table entries: each NGRAPH_OP(name, namespace) line
// registers an op under its introducing opset version. NOTE(review): the hunk
// header ("-158,4 +158,3") indicates this commit removes one of these lines
// (presumably the Asinh v3 entry) — verify against the full table.
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(Asinh, ngraph::op::v3)