Add LogSoftmax-5 to MO and ngraph (#2409)

Co-authored-by: Evgeny Lazarev <evgeny.lazarev@intel.com>
Maxim Vafin 2020-10-20 13:40:06 +03:00 committed by GitHub
parent 83670dd5cb
commit a405546054
23 changed files with 847 additions and 240 deletions

View File

@@ -0,0 +1,26 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API LogSoftmaxDecomposition;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief LogSoftmaxDecomposition transformation decomposes LogSoftmax into the sub-graph x - log(reduce_sum(exp(x), axis)).
*/
class ngraph::pass::LogSoftmaxDecomposition : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
LogSoftmaxDecomposition();
};

View File

@@ -41,6 +41,7 @@
#include "transformations/op_conversions/reduce_l1_decomposition.hpp"
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
#include "transformations/op_conversions/hswish_decomposition.hpp"
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
@@ -78,6 +79,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
decomp->add_matcher<ngraph::pass::HSwishDecomposition>();
decomp->add_matcher<ngraph::pass::LogSoftmaxDecomposition>();
decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();

View File

@@ -0,0 +1,44 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include <memory>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::LogSoftmaxDecomposition, "LogSoftmaxDecomposition", 0);
ngraph::pass::LogSoftmaxDecomposition::LogSoftmaxDecomposition() {
// Decomposes LogSoftmax(x, axis) op into the sub-graph (x - max) - log(reduce_sum(exp(x - max), axis)),
// where max = reduce_max(x, axis); subtracting the max first keeps the computation numerically stable
auto log_softmax = ngraph::pattern::wrap_type<opset5::LogSoftmax>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto log_softmax_node = std::dynamic_pointer_cast<ngraph::opset5::LogSoftmax>(pattern_to_output.at(log_softmax).get_node_shared_ptr());
if (log_softmax_node == nullptr || m_transformation_callback(log_softmax_node)) {
return false;
}
auto axis1 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() });
auto axis2 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() });
auto max = std::make_shared<ngraph::opset5::ReduceMax>(log_softmax_node->input_value(0), axis1, true);
auto sub = std::make_shared<ngraph::opset5::Subtract>(log_softmax_node->input_value(0), max);
auto exp = std::make_shared<ngraph::opset5::Exp>(sub);
auto sum = std::make_shared<ngraph::opset5::ReduceSum>(exp, axis2, true);
auto log = std::make_shared<ngraph::opset5::Log>(sum);
auto sub_end = std::make_shared<ngraph::opset5::Subtract>(sub, log);
sub_end->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(log_softmax_node, { axis1, axis2, max, sub, exp, sum, log, sub_end });
ngraph::replace_node(m.get_match_root(), sub_end);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(log_softmax, "LogSoftmaxDecomposition");
register_matcher(m, callback);
}

View File

@@ -17,7 +17,7 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) {
</port>
</output>
</layer>
<layer name="softmax" id="1" type="LogSoftmax" version="opset5">
<layer name="log_softmax" id="1" type="LogSoftmax" version="opset5">
<data axis="1"/>
<input>
<port id="1" precision="FP32">
@@ -47,7 +47,7 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) {
</edges>
</net>
)V0G0N";
std::string modelV5 = R"V0G0N(
std::string model_ref = R"V0G0N(
<net name="Network" version="5" precision="FP32" batch="1">
<layers>
<layer name="in1" type="Input" precision="FP32" id="0">
@@ -58,16 +58,153 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) {
</port>
</output>
</layer>
<layer name="softmax" id="1" type="LogSoftmax" precision="FP32">
<data axis="1"/>
<layer id="1" name="LogSoftmax/ReduceMax_axis" type="Const">
<output>
<port id="1" precision="I64">
<dim>1</dim>
</port>
</output>
<blobs>
<custom offset="0" size="8"/>
</blobs>
</layer>
<layer id="2" name="LogSoftmax/ReduceMax" type="ReduceMax">
<data keep_dims="True"/>
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
</port>
<port id="1">
<dim>1</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer id="3" name="LogSoftmax/Neg1" type="Power">
<data power="1.0" scale="-1.0" shift="0.0"/>
<input>
<port id="0">
<dim>1</dim>
<dim>1</dim>
</port>
</input>
<output>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer id="4" name="LogSoftmax/Sub/first" type="Eltwise">
<data operation="sum"/>
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
</port>
<port id="1">
<dim>1</dim>
<dim>1</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>1000</dim>
</port>
</output>
</layer>
<layer id="5" name="LogSoftmax/Exp" type="Exp">
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
</port>
</input>
<output>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>1000</dim>
</port>
</output>
</layer>
<layer id="6" name="LogSoftmax/ReduceSum_axis" type="Const">
<output>
<port id="1" precision="I64">
<dim>1</dim>
</port>
</output>
<blobs>
<custom offset="8" size="8"/>
</blobs>
</layer>
<layer id="7" name="LogSoftmax/ReduceSum" type="ReduceSum">
<data keep_dims="True"/>
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
</port>
<port id="1">
<dim>1</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer id="8" name="LogSoftmax/Log" type="Log">
<input>
<port id="0">
<dim>1</dim>
<dim>1</dim>
</port>
</input>
<output>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer id="9" name="LogSoftmax/Neg2" type="Power">
<data power="1.0" scale="-1.0" shift="0.0"/>
<input>
<port id="0">
<dim>1</dim>
<dim>1</dim>
</port>
</input>
<output>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer id="10" name="log_softmax" type="Eltwise">
<data operation="sum"/>
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
</port>
<port id="1">
<dim>1</dim>
<dim>1</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>1000</dim>
</port>
@@ -75,10 +212,25 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) {
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
<edge from-layer="0" from-port="0" to-layer="4" to-port="0"/>
<edge from-layer="1" from-port="1" to-layer="2" to-port="1"/>
<edge from-layer="2" from-port="2" to-layer="3" to-port="0"/>
<edge from-layer="3" from-port="1" to-layer="4" to-port="1"/>
<edge from-layer="4" from-port="2" to-layer="5" to-port="0"/>
<edge from-layer="5" from-port="1" to-layer="7" to-port="0"/>
<edge from-layer="6" from-port="1" to-layer="7" to-port="1"/>
<edge from-layer="7" from-port="2" to-layer="8" to-port="0"/>
<edge from-layer="4" from-port="2" to-layer="10" to-port="0"/>
<edge from-layer="8" from-port="1" to-layer="9" to-port="0"/>
<edge from-layer="9" from-port="1" to-layer="10" to-port="1"/>
</edges>
</net>
)V0G0N";
compareIRs(model, modelV5, 0);
compareIRs(model, model_ref, 16, [](Blob::Ptr& weights) {
auto* data = reinterpret_cast<int64_t*>(weights->buffer().as<int8_t*>());
data[0] = 1;
data[1] = 1;
});
}

View File

@@ -0,0 +1,52 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/log_softmax_decomposition.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, LogSoftmaxDecomposition) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2});
auto log_softmax = std::make_shared<ngraph::opset5::LogSoftmax>(data, 1);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{log_softmax}, ngraph::ParameterVector{data});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::LogSoftmaxDecomposition>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2});
auto axis1_const = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto max = std::make_shared<ngraph::opset5::ReduceMax>(input0, axis1_const, true);
auto sub = std::make_shared<ngraph::opset5::Subtract>(input0, max);
auto exp = std::make_shared<ngraph::opset5::Exp>(sub);
auto axis2_const = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto sum = std::make_shared<ngraph::opset5::ReduceSum>(exp, axis2_const, true);
auto log = std::make_shared<ngraph::opset5::Log>(sum);
auto sub_end = std::make_shared<ngraph::opset5::Subtract>(sub, log);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sub_end}, ngraph::ParameterVector{input0});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@@ -150,7 +150,6 @@ extensions/front/kaldi/tanh_component_ext.py
extensions/front/kaldi/tdnn_component_replacer.py
extensions/front/LayerNorm.py
extensions/front/Log1p.py
extensions/front/LogSoftmax.py
extensions/front/MatMul_normalizer.py
extensions/front/Mish_fusion.py
extensions/front/MoveEmbeddedInputsToInputs.py
@@ -390,6 +389,7 @@ extensions/front/tf/identity_ext.py
extensions/front/tf/identityN_to_identity.py
extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/log_softmax_ext.py
extensions/front/tf/LookupTableInsert_ext.py
extensions/front/tf/LoopCond_ext.py
extensions/front/tf/lrn_ext.py
@@ -905,6 +905,7 @@ mo/ops/expand_dims.py
mo/ops/fill.py
mo/ops/flatten.py
mo/ops/group_norm.py
mo/ops/log_softmax.py
mo/ops/lrn.py
mo/ops/lstmnonlinearity.py
mo/ops/memory.py

View File

@@ -1,92 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.ReduceOps import ReduceMax, ReduceSum
from extensions.ops.activation_ops import Exp, Log
from extensions.ops.elementwise import Sub
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node, rename_nodes
class LogSoftmaxFrontReplacer(FrontReplacementOp):
"""
Replace LogSoftmax operation with ReduceMax + Sub + Exp + ReduceSum + Log + Sub.
More precisely, this transformation implements the following formulas for calculating LogSoftmax:
shifted_data = input_data - ReduceMax(input_data, axis), (1)
output = shifted_data - Log(ReduceSum(Exp(shifted_data), axis)).
These formulas are used to calculate LogSoftmax in the implementations of TensorFlow (see
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/softmax_op_functor.h),
Kaldi (see https://github.com/kaldi-asr/kaldi/blob/master/src/cudamatrix/cu-kernels.cu),
MxNet (see https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/softmax-inl.h).
ONNX implements LogSoftmax according to formulas
flatten_data = Flatten(input_data, axis), (1')
shifted_data = flatten_data - ReduceMax(flatten_data, 1),
z = shifted_data - Log(ReduceSum(Exp(shifted_data), 1)),
output = Reshape(z, input_data.shape)
(see https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/codegen/mti/math/logsoftmax.cc,
https://github.com/microsoft/onnxruntime-tvm/blob/master/topi/include/topi/nn/softmax.h)
Formally speaking, formula (1) is equivalent to the formula
output = Log(SoftMax(input_data, axis)) (2)
but LogSoftmax is calculated according to formula (1) for better numerical stability.
"""
op = "LogSoftmax"
enabled = True
def replace_op(self, graph: Graph, node: Node):
node_name = node.soft_get('name', node.id)
assert node.has_valid('axis'), 'The node "{}" does not have mandatory attribute "axis"'.format(node_name)
# Creating of ReduceMax -> Sub -> Exp block
first_sub_node = Sub(graph, {'name': node_name + '/Sub_/first_'}).create_node()
reduce_max_node = create_op_with_const_inputs(graph,
ReduceMax,
{1: int64_array([node.axis])},
op_attrs={'name': node_name + '/ReduceMax_', 'keep_dims': True})
reduce_max_node.out_port(0).connect(first_sub_node.in_port(1))
# Creating of Exp -> ReduceSum -> Log block
exp_node = Exp(graph, {'name': node_name + '/Exp_'}).create_node()
reduce_sum_node = create_op_with_const_inputs(graph,
ReduceSum,
{1: int64_array([node.axis])},
op_attrs={'name': node_name + '/ReduceSum_', 'keep_dims': True})
log_node = Log(graph, {'name': node_name + '/Log_'}).create_node()
first_sub_node.out_port(0).connect(exp_node.in_port(0))
exp_node.out_port(0).connect(reduce_sum_node.in_port(0))
reduce_sum_node.out_port(0).connect(log_node.in_port(0))
# Creating of the last Sub node
second_sub_node = Sub(graph, {}).create_node()
rename_nodes([(node, node_name + '/delete'), (second_sub_node, node_name)])
log_node.out_port(0).connect(second_sub_node.in_port(1))
first_sub_node.out_port(0).connect(second_sub_node.in_port(0))
# Correcting of input edges
source = node.in_port(0).get_source()
first_sub_node.in_port(0).connect(source)
reduce_max_node.in_port(0).connect(source)
return [second_sub_node.id]
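A short NumPy sketch (illustration only, not part of the commit) that cross-checks the docstring above: the max-shifted formula (1), which this replacer builds, agrees with the plain log(softmax(x)) of formula (2) and stays finite on large inputs where the naive form overflows. The helper names are made up for the example.

import numpy as np

def log_softmax_shifted(x, axis=-1):
    # formula (1): subtract the max along the axis before exponentiating
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

def log_softmax_naive(x, axis=-1):
    # formula (2): log(softmax(x)) without the shift
    e = np.exp(x)
    return np.log(e / np.sum(e, axis=axis, keepdims=True))

x = np.random.randn(3, 4)
assert np.allclose(log_softmax_shifted(x), log_softmax_naive(x))

big = np.array([[0., 1., 2., 3.], [10000., 10001., 10002., 10003.]])
print(log_softmax_shifted(big, axis=-1))  # finite; the naive form would overflow to inf/nan here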

View File

@@ -1,86 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from generator import generator, generate
from extensions.front.LogSoftmax import LogSoftmaxFrontReplacer
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
graph_node_attributes = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'logsoftmax': {'type': None, 'kind': 'op', 'op': 'LogSoftmax', 'axis': -1},
'output': {'kind': 'op', 'type': 'Result', 'op': 'Result'},
}
graph_edges = [
('placeholder', 'logsoftmax'),
('logsoftmax', 'output'),
]
graph_ref_node_attributes = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'exp': {'type': 'Exp', 'kind': 'op', 'op': 'Exp'},
'reduce_sum': {'type': 'ReduceSum', 'kind': 'op', 'op': 'ReduceSum', 'keep_dims': True},
'reduce_max': {'type': 'ReduceMax', 'kind': 'op', 'op': 'ReduceMax', 'keep_dims': True},
'log': {'type': 'Log', 'kind': 'op', 'op': 'Log'},
'second_sub': {'type': 'Subtract', 'kind': 'op', 'op': 'Sub'},
'reduce_sum_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': int64_array([1])},
'reduce_max_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': int64_array([1])},
'first_sub': {'type': 'Subtract', 'kind': 'op', 'op': 'Sub'},
'output': {'kind': 'op', 'type': 'Result', 'op': 'Result'},
}
graph_ref_edges = [
('placeholder', 'reduce_max', {'in': 0, 'out': 0}),
('placeholder', 'first_sub', {'in': 0, 'out': 0}),
('reduce_max', 'first_sub', {'in': 1}),
('reduce_max_axis', 'reduce_max', {'in': 1}),
('first_sub', 'exp', {'in': 0, 'out': 0}),
('first_sub', 'second_sub', {'in': 0, 'out': 0}),
('exp', 'reduce_sum', {'in': 0}),
('reduce_sum_axis', 'reduce_sum', {'in': 1}),
('reduce_sum', 'log'),
('log', 'second_sub', {'in': 1}),
('second_sub', 'output'),
]
@generator
class LogSoftmaxReplacerTest(unittest.TestCase):
@generate(*[(-1, 'NCHW'), (-1, 'NHWC'), (0, 'NHWC'),
(0, 'NCHW'), (2, 'NCHW'), (2, 'NHWC'),
(-2, 'NHWC'), (-2, 'NCHW')])
def test_logsoftmax_replacer(self, axis, layout):
graph = build_graph(nodes_attrs=graph_node_attributes, edges=graph_edges)
graph_ref = build_graph(nodes_attrs=graph_ref_node_attributes,
edges=graph_ref_edges,
update_attributes={
'reduce_max_axis': {'value': int64_array([axis])},
'reduce_sum_axis': {'value': int64_array([axis])},
})
graph.graph['layout'] = layout
graph.stage = 'front'
LogSoftmaxFrontReplacer().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output')
self.assertTrue(flag, resp)

View File

@@ -14,7 +14,7 @@
limitations under the License.
"""
from mo.ops.softmax import LogSoftmax
from mo.ops.log_softmax import LogSoftmax
from mo.front.extractor import FrontExtractorOp

View File

@@ -36,10 +36,6 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
"""
enabled = True
def run_before(self):
from extensions.front.LogSoftmax import LogSoftmaxFrontReplacer
return [LogSoftmaxFrontReplacer]
def pattern(self):
return dict(nodes=[('flatten', dict(op='FlattenONNX'))],
edges=[])

View File

@@ -18,7 +18,7 @@ from mo.graph.graph import Graph, Node, rename_nodes
from mo.ops.flatten import FlattenONNX
from mo.ops.reshape import Reshape
from mo.ops.shape import Shape
from mo.ops.softmax import LogSoftmax
from mo.ops.log_softmax import LogSoftmax
class LogSoftmaxONNXFrontReplacer(FrontReplacementOp):

View File

@@ -16,7 +16,8 @@
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.softmax import LogSoftmaxONNX, SoftmaxONNX
from mo.ops.softmax import SoftmaxONNX
from mo.ops.log_softmax import LogSoftmaxONNX
class SoftmaxExtractor(FrontExtractorOp):

View File

@@ -0,0 +1,32 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.extractor import FrontExtractorOp
from mo.ops.log_softmax import LogSoftmax
class LogSoftmaxExtractor(FrontExtractorOp):
op = 'LogSoftmax'
enabled = True
@classmethod
def extract(cls, node):
# the default value for the TF LogSoftmax is -1
axis = -1
if 'axis' in node.pb.attr:
axis = node.pb.attr['axis'].i
LogSoftmax.update_node_stat(node, {'axis': axis})
return cls.enabled

View File

@@ -15,7 +15,7 @@
"""
from mo.front.extractor import FrontExtractorOp
from mo.ops.softmax import LogSoftmax, Softmax
from mo.ops.softmax import Softmax
class SoftmaxExtractor(FrontExtractorOp):
@@ -30,17 +30,3 @@ class SoftmaxExtractor(FrontExtractorOp):
axis = node.pb.attr['axis'].i
Softmax.update_node_stat(node, {'axis': axis})
return cls.enabled
class LogSoftmaxExtractor(FrontExtractorOp):
op = 'LogSoftmax'
enabled = True
@classmethod
def extract(cls, node):
# the default value for the TF LogSoftmax is -1
axis = -1
if 'axis' in node.pb.attr:
axis = node.pb.attr['axis'].i
LogSoftmax.update_node_stat(node, {'axis': axis})
return cls.enabled

View File

@@ -0,0 +1,67 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.graph.graph import Graph, Node
from mo.ops.op import Op, PermuteAttrs
class LogSoftmax(Op):
op = 'LogSoftmax'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': self.op,
'op': self.op,
'version': 'opset5',
'infer': self.infer,
'axis': 1,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
def supported_attrs(self):
return ['axis']
@staticmethod
def infer(node: Node):
assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 1,\
'LogSoftmax node with id {} does not have exactly one input port connected'.format(node.id)
if node.axis < 0:
node.axis = len(node.in_port(0).data.get_shape()) + node.axis
assert 0 <= node.axis < len(node.in_port(0).data.get_shape()),\
'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)
copy_shape_infer(node)
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
class LogSoftmaxONNX(Op):
op = 'LogSoftmaxONNX'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'infer': None,
'kind': 'op',
'axis': 1,
'type': None, # the operation will be replaced with a
# Reshape(LogSoftmax(FlattenONNX(x, axis), 1), x.shape) sub-graph
'op': self.op,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)

View File

@@ -59,35 +59,3 @@ class SoftmaxONNX(Op):
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
class LogSoftmax(Op):
op = 'LogSoftmax'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'infer': None,
'kind': 'op',
'axis': 1,
'type': None, # the operation will be replaced with a x - Log(ReduceSum(Exp(x), axis)) sub-graph
'op': __class__.op,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
class LogSoftmaxONNX(Op):
op = 'LogSoftmaxONNX'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'infer': None,
'kind': 'op',
'axis': 1,
'type': None, # the operation will be replaced with a
# Reshape(LogSoftmax(FlattenONNX(x, axis), 1), x.shape) sub-graph
'op': __class__.op,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)

View File

@@ -0,0 +1,62 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void log_softmax(const T* arg, T* out, const Shape& shape, const AxisSet& axes)
{
auto temp_shape = reduce(shape, axes, true);
auto temp_elements = shape_size(temp_shape);
auto temp_max = std::vector<T>(temp_elements, 0);
auto temp_sum = std::vector<T>(temp_elements, 0);
max(arg, temp_max.data(), shape, axes, true);
CoordinateTransform transform(shape);
CoordinateTransform temp_transform(temp_shape);
for (const Coordinate& coord : transform)
{
Coordinate temp_coord = reduce(coord, axes, true);
out[transform.index(coord)] = std::exp(
arg[transform.index(coord)] - temp_max[temp_transform.index(temp_coord)]);
}
sum(out, temp_sum.data(), shape, axes, true);
for (const Coordinate& coord : transform)
{
Coordinate temp_coord = reduce(coord, axes, true);
out[transform.index(coord)] =
(arg[transform.index(coord)] - temp_max[temp_transform.index(temp_coord)]) -
std::log(temp_sum[temp_transform.index(temp_coord)]);
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@@ -294,6 +294,7 @@ set(MULTI_TEST_SRC
backend/group_convolution.in.cpp
backend/interpolate.in.cpp
backend/log.in.cpp
backend/log_softmax.in.cpp
backend/logical_or.in.cpp
backend/logical_xor.in.cpp
backend/lrn.in.cpp

View File

@@ -0,0 +1,368 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// clang-format off
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#endif
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif
// clang-format on
#include "gtest/gtest.h"
#include "runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_1d_single_value)
{
Shape shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{0};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, 0), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis0)
{
Shape shape{2, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-10000., -10000., -10000., -10000., 0., 0., 0., 0.};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, 0), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis1)
{
Shape shape{2, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-3.4401896,
-2.4401896,
-1.4401897,
-0.4401897,
-3.4401896,
-2.4401896,
-1.4401897,
-0.4401897};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, 1), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis_neg1)
{
Shape shape{2, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-3.4401896,
-2.4401896,
-1.4401897,
-0.4401897,
-3.4401896,
-2.4401896,
-1.4401897,
-0.4401897};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, -1), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis_neg2)
{
Shape shape{2, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-10000., -10000., -10000., -10000., 0., 0., 0., 0.};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, -2), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_0)
{
Shape shape{3, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-12.0024818,
-12.0024818,
-12.0024818,
-12.0024818,
-12.0024818,
-12.0024818,
-6.00248181,
-6.00248181,
-6.00248181,
-6.00248181,
-6.00248181,
-6.00248181,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, 0), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_1)
{
Shape shape{3, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-3.04858735,
-3.04858735,
-3.04858735,
-0.04858735,
-0.04858735,
-0.04858735,
-3.04858735,
-3.04858735,
-3.04858735,
-0.04858735,
-0.04858735,
-0.04858735,
-3.04858735,
-3.04858735,
-3.04858735,
-0.04858735,
-0.04858735,
-0.04858735};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, 1), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_2)
{
Shape shape{3, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, 2), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_neg1)
{
Shape shape{3, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596,
-2.40760596,
-1.40760596,
-0.40760596};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, -1), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_neg2)
{
Shape shape{3, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-3.04858735,
-3.04858735,
-3.04858735,
-0.04858735,
-0.04858735,
-0.04858735,
-3.04858735,
-3.04858735,
-3.04858735,
-0.04858735,
-0.04858735,
-0.04858735,
-3.04858735,
-3.04858735,
-3.04858735,
-0.04858735,
-0.04858735,
-0.04858735};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, -2), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_neg3)
{
Shape shape{3, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, shape);
std::vector<float> expected_result{-12.0024818,
-12.0024818,
-12.0024818,
-12.0024818,
-12.0024818,
-12.0024818,
-6.00248181,
-6.00248181,
-6.00248181,
-6.00248181,
-6.00248181,
-6.00248181,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03,
-2.48181414e-03};
auto f = make_shared<Function>(make_shared<op::v5::LogSoftmax>(A, -3), ParameterVector{A});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
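The hard-coded expected_result vectors in these tests can be reproduced with a few lines of NumPy (a cross-checking sketch, not part of the test suite); for instance, the repeating -2.40760596 / -1.40760596 / -0.40760596 pattern of the axis-2 case comes out of:

import numpy as np

x = np.arange(-9.0, 9.0).reshape(3, 2, 3)    # same values the tests copy into the input tensor
shifted = x - x.max(axis=2, keepdims=True)   # reduce over axis 2, keep dims for broadcasting
ref = shifted - np.log(np.exp(shifted).sum(axis=2, keepdims=True))
print(ref.ravel())                           # -2.40760596 -1.40760596 -0.40760596 ...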

View File

@@ -85,6 +85,10 @@ namespace
ie_ops.insert(opset2.begin(), opset2.end());
auto& opset3 = get_opset3().get_type_info_set();
ie_ops.insert(opset3.begin(), opset3.end());
auto& opset4 = get_opset4().get_type_info_set();
ie_ops.insert(opset4.begin(), opset4.end());
auto& opset5 = get_opset5().get_type_info_set();
ie_ops.insert(opset5.begin(), opset5.end());
return ie_ops;
}
}

View File

@@ -1130,6 +1130,13 @@ IE_CPU.onnx_resize11_scales_nearest_asymmetric_floor_dynamic_sizes
# Input data precision not supported. Expected float.
ctc_greedy_decoder_f16
# Wrong output when the axis is 0
IE_CPU.log_softmax_1d_single_value
IE_CPU.log_softmax_2d_axis0
IE_CPU.log_softmax_2d_axis_neg2
IE_CPU.log_softmax_3d_axis_0
IE_CPU.log_softmax_3d_axis_neg3
#-------------------------------------------------------------------------------
#
# Inference Engine GPU plugin excludes

View File

@@ -62,6 +62,7 @@
#include "ngraph/runtime/reference/gather_tree.hpp"
#include "ngraph/runtime/reference/gru_cell.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/log_softmax.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/lstm_cell.hpp"
#include "ngraph/runtime/reference/matmul.hpp"
@@ -874,6 +875,20 @@ protected:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LogSoftmax_v5:
{
const op::v5::LogSoftmax* log_softmax = static_cast<const op::v5::LogSoftmax*>(&node);
int64_t i_axis = log_softmax->get_axis();
if (i_axis < 0)
{
i_axis += args[0]->get_partial_shape().rank().get_length();
}
reference::log_softmax<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_output_shape(0),
AxisSet{(size_t)i_axis});
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);

View File

@@ -57,4 +57,5 @@ NGRAPH_OP(GatherND, op::v5)
NGRAPH_OP(LSTMSequence, op::v5)
NGRAPH_OP(GRUSequence, op::v5)
NGRAPH_OP(RNNSequence, op::v5)
NGRAPH_OP(LogSoftmax, op::v5)
#undef ID_SUFFIX