Change Elu a regular op since decomposition works extremely slowly (#582)
* Moved Elu operation from Fused to regular ones because the decomposition works extremely slowly. * Added reference implementation for the Elu op
This commit is contained in:
parent
73f3b7c8fc
commit
c1625743df
@ -189,6 +189,8 @@ set (SRC
|
||||
op/divide.hpp
|
||||
op/dot.cpp
|
||||
op/dot.hpp
|
||||
op/elu.cpp
|
||||
op/elu.hpp
|
||||
op/embeddingbag_offsets_sum.cpp
|
||||
op/embeddingbag_offsets_sum.hpp
|
||||
op/embedding_lookup.cpp
|
||||
@ -430,8 +432,6 @@ set (SRC
|
||||
op/fused/hard_sigmoid.hpp
|
||||
op/fused/depth_to_space.cpp
|
||||
op/fused/depth_to_space.hpp
|
||||
op/fused/elu.cpp
|
||||
op/fused/elu.hpp
|
||||
op/fused/fake_quantize.cpp
|
||||
op/fused/fake_quantize.hpp
|
||||
op/fused/gelu.cpp
|
||||
|
@ -13,18 +13,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
#include "ngraph/op/fused/elu.hpp"
|
||||
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/builder/autobroadcast.hpp"
|
||||
#include "ngraph/builder/make_constant.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/exp.hpp"
|
||||
#include "ngraph/op/maximum.hpp"
|
||||
#include "ngraph/op/minimum.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/op/elu.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -32,7 +25,7 @@ using namespace ngraph;
|
||||
constexpr NodeTypeInfo op::Elu::type_info;
|
||||
|
||||
op::Elu::Elu(const Output<Node>& data, const double alpha)
|
||||
: FusedOp({data})
|
||||
: Op({data})
|
||||
, m_alpha{alpha}
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
@ -44,21 +37,10 @@ bool ngraph::op::v0::Elu::visit_attributes(AttributeVisitor& visitor)
|
||||
return true;
|
||||
}
|
||||
|
||||
NodeVector op::Elu::decompose_op() const
|
||||
void op::v0::Elu::validate_and_infer_types()
|
||||
{
|
||||
auto data = input_value(0);
|
||||
shared_ptr<Node> alpha_node =
|
||||
make_shared<op::Constant>(data.get_element_type(), Shape{}, vector<double>{m_alpha});
|
||||
|
||||
alpha_node = builder::numpy_broadcast(alpha_node, data.get_shape());
|
||||
|
||||
shared_ptr<ngraph::Node> zero_node =
|
||||
builder::make_constant(data.get_element_type(), data.get_shape(), 0);
|
||||
|
||||
return {make_shared<ngraph::op::Maximum>(data, zero_node) +
|
||||
alpha_node *
|
||||
make_shared<ngraph::op::Exp>(make_shared<ngraph::op::Minimum>(data, zero_node)) -
|
||||
alpha_node};
|
||||
set_output_size(1);
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::Elu::clone_with_new_inputs(const OutputVector& new_args) const
|
@ -18,7 +18,6 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -30,7 +29,7 @@ namespace ngraph
|
||||
/// x < 0 => f(x) = alpha * (exp(x) - 1.)
|
||||
/// x >= 0 => f(x) = x
|
||||
///
|
||||
class NGRAPH_API Elu : public ngraph::op::util::FusedOp
|
||||
class NGRAPH_API Elu : public ngraph::op::Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Elu", 0};
|
||||
@ -43,7 +42,7 @@ namespace ngraph
|
||||
Elu(const Output<Node>& data, const double alpha);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual NodeVector decompose_op() const override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
@ -57,6 +57,7 @@
|
||||
#include "ngraph/op/detection_output.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/dot.hpp"
|
||||
#include "ngraph/op/elu.hpp"
|
||||
#include "ngraph/op/embedding_lookup.hpp"
|
||||
#include "ngraph/op/embedding_segments_sum.hpp"
|
||||
#include "ngraph/op/embeddingbag_offsets_sum.hpp"
|
||||
@ -84,7 +85,6 @@
|
||||
#include "ngraph/op/fused/conv_fused.hpp"
|
||||
#include "ngraph/op/fused/crossentropy.hpp"
|
||||
#include "ngraph/op/fused/depth_to_space.hpp"
|
||||
#include "ngraph/op/fused/elu.hpp"
|
||||
#include "ngraph/op/fused/fake_quantize.hpp"
|
||||
#include "ngraph/op/fused/gelu.hpp"
|
||||
#include "ngraph/op/fused/gemm.hpp"
|
||||
|
38
ngraph/src/ngraph/runtime/reference/elu.hpp
Normal file
38
ngraph/src/ngraph/runtime/reference/elu.hpp
Normal file
@ -0,0 +1,38 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename T>
|
||||
void elu(const T* arg, T* out, size_t count, double alpha)
|
||||
{
|
||||
for (size_t i = 0; i < count; i++)
|
||||
{
|
||||
out[i] = arg[i] < 0 ? alpha * (std::exp(arg[i]) - 1.0) : arg[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -53,6 +53,7 @@
|
||||
#include "ngraph/runtime/reference/cum_sum.hpp"
|
||||
#include "ngraph/runtime/reference/dequantize.hpp"
|
||||
#include "ngraph/runtime/reference/dot.hpp"
|
||||
#include "ngraph/runtime/reference/elu.hpp"
|
||||
#include "ngraph/runtime/reference/embedding_lookup.hpp"
|
||||
#include "ngraph/runtime/reference/erf.hpp"
|
||||
#include "ngraph/runtime/reference/exp.hpp"
|
||||
@ -312,6 +313,17 @@ protected:
|
||||
element_count);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::Elu:
|
||||
{
|
||||
const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
|
||||
|
||||
size_t element_count = shape_size(node.get_output_shape(0));
|
||||
reference::elu<T>(args[0]->get_data_ptr<const T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
element_count,
|
||||
elu_node->get_alpha());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::AvgPool:
|
||||
{
|
||||
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
|
||||
@ -1429,7 +1441,6 @@ protected:
|
||||
case OP_TYPEID::DynPad:
|
||||
case OP_TYPEID::DynReplaceSlice:
|
||||
case OP_TYPEID::DynSlice:
|
||||
case OP_TYPEID::Elu:
|
||||
case OP_TYPEID::FakeQuantize:
|
||||
case OP_TYPEID::Gather:
|
||||
case OP_TYPEID::Gelu:
|
||||
|
Loading…
Reference in New Issue
Block a user