From a405546054b031b043cfc5fc3fcef0d81fc5a1a5 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Tue, 20 Oct 2020 13:40:06 +0300 Subject: [PATCH] Add LogSoftmax-5 to MO and ngraph (#2409) Co-authored-by: Evgeny Lazarev --- .../log_softmax_decomposition.hpp | 26 ++ .../common_optimizations.cpp | 2 + .../log_softmax_decomposition.cpp | 44 +++ .../ngraph_reader/log_softmax_tests.cpp | 166 +++++++- .../log_softmax_decomposition_test.cpp | 52 +++ model-optimizer/automation/package_BOM.txt | 3 +- .../extensions/front/LogSoftmax.py | 92 ----- .../extensions/front/LogSoftmax_test.py | 86 ---- .../front/kaldi/logsoftmax_component_ext.py | 2 +- .../front/onnx/flattenONNX_to_reshape.py | 4 - .../onnx/logsoftmaxONNX_to_logsoftmax.py | 2 +- .../extensions/front/onnx/softmax_ext.py | 3 +- .../extensions/front/tf/log_softmax_ext.py | 32 ++ .../extensions/front/tf/softmax_ext.py | 16 +- model-optimizer/mo/ops/log_softmax.py | 67 ++++ model-optimizer/mo/ops/softmax.py | 32 -- .../ngraph/runtime/reference/log_softmax.hpp | 62 +++ ngraph/test/CMakeLists.txt | 1 + ngraph/test/backend/log_softmax.in.cpp | 368 ++++++++++++++++++ ngraph/test/runtime/ie/ie_executable.cpp | 4 + ngraph/test/runtime/ie/unit_test.manifest | 7 + .../runtime/interpreter/int_executable.hpp | 15 + .../runtime/interpreter/opset_int_tbl.hpp | 1 + 23 files changed, 847 insertions(+), 240 deletions(-) create mode 100644 inference-engine/src/transformations/include/transformations/op_conversions/log_softmax_decomposition.hpp create mode 100644 inference-engine/src/transformations/src/transformations/op_conversions/log_softmax_decomposition.cpp create mode 100644 inference-engine/tests/functional/inference_engine/transformations/log_softmax_decomposition_test.cpp delete mode 100644 model-optimizer/extensions/front/LogSoftmax.py delete mode 100644 model-optimizer/extensions/front/LogSoftmax_test.py create mode 100644 model-optimizer/extensions/front/tf/log_softmax_ext.py create mode 100644 model-optimizer/mo/ops/log_softmax.py create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/log_softmax.hpp create mode 100644 ngraph/test/backend/log_softmax.in.cpp diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/log_softmax_decomposition.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/log_softmax_decomposition.hpp new file mode 100644 index 00000000000..acbcf40f2d0 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/op_conversions/log_softmax_decomposition.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace ngraph { +namespace pass { + + class TRANSFORMATIONS_API LogSoftmaxDecomposition; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief LogSoftmaxDecomposition transformation into sub-graph x - log(reduce_sum(exp(x), axis)). 
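 *
 * Note: the sub-graph emitted by this pass also subtracts the per-axis maximum before the
 * exponentiation for numerical stability, i.e. with t = x - reduce_max(x, axis) the output is
 * computed as t - log(reduce_sum(exp(t), axis)), which is mathematically equivalent to the
 * formula above.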
+ */ +class ngraph::pass::LogSoftmaxDecomposition : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + LogSoftmaxDecomposition(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index f4e5df8600a..059faa72337 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -41,6 +41,7 @@ #include "transformations/op_conversions/reduce_l1_decomposition.hpp" #include "transformations/op_conversions/reduce_l2_decomposition.hpp" #include "transformations/op_conversions/hswish_decomposition.hpp" +#include "transformations/op_conversions/log_softmax_decomposition.hpp" #include #include @@ -78,6 +79,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); decomp->add_matcher(); decomp->add_matcher(); + decomp->add_matcher(); decomp->add_matcher(); decomp->add_matcher(); decomp->add_matcher(); diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/log_softmax_decomposition.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/log_softmax_decomposition.cpp new file mode 100644 index 00000000000..12c4d2535bd --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/op_conversions/log_softmax_decomposition.cpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/log_softmax_decomposition.hpp" + +#include + +#include +#include +#include + +NGRAPH_RTTI_DEFINITION(ngraph::pass::LogSoftmaxDecomposition, "LogSoftmaxDecomposition", 0); + +ngraph::pass::LogSoftmaxDecomposition::LogSoftmaxDecomposition() { + // Decomposes LogSoftmax(x, axis) op into sub-graph x - log(reduce_sum(exp(x), axis)) + auto log_softmax = ngraph::pattern::wrap_type(); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto log_softmax_node = std::dynamic_pointer_cast(pattern_to_output.at(log_softmax).get_node_shared_ptr()); + + if (m_transformation_callback(log_softmax_node)) { + return false; + } + + auto axis1 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() }); + auto axis2 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() }); + auto max = std::make_shared(log_softmax_node->input_value(0), axis1, true); + auto sub = std::make_shared(log_softmax_node->input_value(0), max); + auto exp = std::make_shared(sub); + auto sum = std::make_shared(exp, axis2, true); + auto log = std::make_shared(sum); + auto sub_end = std::make_shared(sub, log); + + sub_end->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info(log_softmax_node, { axis1, axis2, max, sub, exp, sum, log, sub_end }); + ngraph::replace_node(m.get_match_root(), sub_end); + return true; + }; + + auto m = std::make_shared(log_softmax, "LogSoftmaxDecomposition"); + register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/log_softmax_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/log_softmax_tests.cpp index 
f3a0a01d074..7b452f4f1bd 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reader/log_softmax_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/log_softmax_tests.cpp @@ -17,7 +17,7 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) { - + @@ -47,7 +47,7 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) { )V0G0N"; - std::string modelV5 = R"V0G0N( + std::string model_ref = R"V0G0N( @@ -58,16 +58,153 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) { - - + + + + 1 + + + + + + + + + + 1 + 1000 + 1 + + + + + 1 + 1 + + + + + + + + 1 + 1 + + + + + 1 + 1 + + + + + + + + 1 + 1000 + + + 1 + 1 + + + + + 1 + 1000 + + + + + + + 1 1000 - + + 1 + 1000 + + + + + + + 1 + + + + + + + + + + + 1 + 1000 + + + 1 + + + + + 1 + 1 + + + + + + + 1 + 1 + + + + + 1 + 1 + + + + + + + + 1 + 1 + + + + + 1 + 1 + + + + + + + + 1 + 1000 + + + 1 + 1 + + + + 1 1000 @@ -75,10 +212,25 @@ TEST_F(NGraphReaderTests, ReadLogSoftmaxNetwork) { - + + + + + + + + + + + + )V0G0N"; - compareIRs(model, modelV5, 0); + compareIRs(model, model_ref, 16, [](Blob::Ptr& weights) { + auto* data = reinterpret_cast(weights->buffer().as()); + data[0] = 1; + data[1] = 1; + }); } diff --git a/inference-engine/tests/functional/inference_engine/transformations/log_softmax_decomposition_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/log_softmax_decomposition_test.cpp new file mode 100644 index 00000000000..b6e5884fd20 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/log_softmax_decomposition_test.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST(TransformationTests, LogSoftmaxDecomposition) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2}); + auto log_softmax = std::make_shared(data, 1); + + f = std::make_shared(ngraph::NodeVector{log_softmax}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input0 = std::make_shared(ngraph::element::f64, ngraph::Shape{3, 2}); + auto axis1_const = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); + auto max = std::make_shared(input0, axis1_const, true); + auto sub = std::make_shared(input0, max); + auto exp = std::make_shared(sub); + auto axis2_const = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); + auto sum = std::make_shared(exp, axis2_const, true); + auto log = std::make_shared(sum); + auto sub_end = std::make_shared(sub, log); + + f_ref = std::make_shared(ngraph::NodeVector{sub_end}, ngraph::ParameterVector{input0}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index b3a599a17ce..4990bd90e6e 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -150,7 +150,6 @@ extensions/front/kaldi/tanh_component_ext.py extensions/front/kaldi/tdnn_component_replacer.py extensions/front/LayerNorm.py extensions/front/Log1p.py -extensions/front/LogSoftmax.py 
extensions/front/MatMul_normalizer.py extensions/front/Mish_fusion.py extensions/front/MoveEmbeddedInputsToInputs.py @@ -390,6 +389,7 @@ extensions/front/tf/identity_ext.py extensions/front/tf/identityN_to_identity.py extensions/front/tf/InterpolateTransposes.py extensions/front/tf/IteratorGetNext_ext.py +extensions/front/tf/log_softmax_ext.py extensions/front/tf/LookupTableInsert_ext.py extensions/front/tf/LoopCond_ext.py extensions/front/tf/lrn_ext.py @@ -905,6 +905,7 @@ mo/ops/expand_dims.py mo/ops/fill.py mo/ops/flatten.py mo/ops/group_norm.py +mo/ops/log_softmax.py mo/ops/lrn.py mo/ops/lstmnonlinearity.py mo/ops/memory.py diff --git a/model-optimizer/extensions/front/LogSoftmax.py b/model-optimizer/extensions/front/LogSoftmax.py deleted file mode 100644 index 4a7e2de28fa..00000000000 --- a/model-optimizer/extensions/front/LogSoftmax.py +++ /dev/null @@ -1,92 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -from extensions.ops.ReduceOps import ReduceMax, ReduceSum -from extensions.ops.activation_ops import Exp, Log -from extensions.ops.elementwise import Sub -from mo.front.common.partial_infer.utils import int64_array -from mo.front.common.replacement import FrontReplacementOp -from mo.front.tf.graph_utils import create_op_with_const_inputs -from mo.graph.graph import Graph, Node, rename_nodes - - -class LogSoftmaxFrontReplacer(FrontReplacementOp): - """ - Replace LogSoftmax operation with ReduceMax + Sub + Exp + ReduceSum + Log + Sub. - - More precisely, this transformation implements the following formulas of the calculation of LogSoftmax: - - shifted_data = input_data - ReduceMax(input_data, axis), (1) - output = shifted_data - Log(ReduceSum(Exp(shifted_data), axis)). - - These formulas is used to calculate LogSoftmax in implementation of TensorFlow (see - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/softmax_op_functor.h), - Kaldi (see https://github.com/kaldi-asr/kaldi/blob/master/src/cudamatrix/cu-kernels.cu), - MxNet (see https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/softmax-inl.h). - - ONNX implements LogSoftmax according to formulas - - flatten_data = Flatten(input_data, axis), (1') - shifted_data = flatten_data - ReduceMax(flatten_data, 1), - z = shifted_data - Log(ReduceSum(Exp(shifted_data), 1)), - output = Reshape(z, input_data.shape) - - (see https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/codegen/mti/math/logsoftmax.cc, - https://github.com/microsoft/onnxruntime-tvm/blob/master/topi/include/topi/nn/softmax.h) - - Formally speaking, the formula (1) is equivalent to the formula - output = Log(SoftMax(input_data, axis)) (2) - - But LogSoftMax is calculated according to formula (1) for better numeric stability. 
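 - As a concrete illustration of why formula (1) is preferred (a worked example, not part of the original docstring): for input_data = [1000, 1001] and axis = 0, a naive evaluation of formula (2) would require exp(1000), which overflows in float arithmetic, whereas formula (1) first shifts the data to shifted_data = [-1, 0] and yields output ≈ [-1.3133, -0.3133] without overflow.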
- """ - op = "LogSoftmax" - enabled = True - - def replace_op(self, graph: Graph, node: Node): - node_name = node.soft_get('name', node.id) - assert node.has_valid('axis'), 'The node "{}" does not have mandatory attribute "axis"'.format(node_name) - - # Creating of ReduceMax -> Sub -> Exp block - first_sub_node = Sub(graph, {'name': node_name + '/Sub_/first_'}).create_node() - reduce_max_node = create_op_with_const_inputs(graph, - ReduceMax, - {1: int64_array([node.axis])}, - op_attrs={'name': node_name + '/ReduceMax_', 'keep_dims': True}) - reduce_max_node.out_port(0).connect(first_sub_node.in_port(1)) - - # Creating of Exp -> ReduceSum -> Log block - exp_node = Exp(graph, {'name': node_name + '/Exp_'}).create_node() - reduce_sum_node = create_op_with_const_inputs(graph, - ReduceSum, - {1: int64_array([node.axis])}, - op_attrs={'name': node_name + '/ReduceSum_', 'keep_dims': True}) - log_node = Log(graph, {'name': node_name + '/Log_'}).create_node() - - first_sub_node.out_port(0).connect(exp_node.in_port(0)) - exp_node.out_port(0).connect(reduce_sum_node.in_port(0)) - reduce_sum_node.out_port(0).connect(log_node.in_port(0)) - - # Creating of the last Sub node - second_sub_node = Sub(graph, {}).create_node() - rename_nodes([(node, node_name + '/delete'), (second_sub_node, node_name)]) - log_node.out_port(0).connect(second_sub_node.in_port(1)) - first_sub_node.out_port(0).connect(second_sub_node.in_port(0)) - - # Correcting of input edges - source = node.in_port(0).get_source() - first_sub_node.in_port(0).connect(source) - reduce_max_node.in_port(0).connect(source) - - return [second_sub_node.id] diff --git a/model-optimizer/extensions/front/LogSoftmax_test.py b/model-optimizer/extensions/front/LogSoftmax_test.py deleted file mode 100644 index 5c5050cb3a8..00000000000 --- a/model-optimizer/extensions/front/LogSoftmax_test.py +++ /dev/null @@ -1,86 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-""" - -import unittest - -from generator import generator, generate - -from extensions.front.LogSoftmax import LogSoftmaxFrontReplacer -from mo.front.common.partial_infer.utils import int64_array -from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph - -graph_node_attributes = { - 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'logsoftmax': {'type': None, 'kind': 'op', 'op': 'LogSoftmax', 'axis': -1}, - 'output': {'kind': 'op', 'type': 'Result', 'op': 'Result'}, -} - - -graph_edges = [ - ('placeholder', 'logsoftmax'), - ('logsoftmax', 'output'), -] - - -graph_ref_node_attributes = { - 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'exp': {'type': 'Exp', 'kind': 'op', 'op': 'Exp'}, - 'reduce_sum': {'type': 'ReduceSum', 'kind': 'op', 'op': 'ReduceSum', 'keep_dims': True}, - 'reduce_max': {'type': 'ReduceMax', 'kind': 'op', 'op': 'ReduceMax', 'keep_dims': True}, - 'log': {'type': 'Log', 'kind': 'op', 'op': 'Log'}, - 'second_sub': {'type': 'Subtract', 'kind': 'op', 'op': 'Sub'}, - 'reduce_sum_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': int64_array([1])}, - 'reduce_max_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': int64_array([1])}, - 'first_sub': {'type': 'Subtract', 'kind': 'op', 'op': 'Sub'}, - 'output': {'kind': 'op', 'type': 'Result', 'op': 'Result'}, -} - - -graph_ref_edges = [ - ('placeholder', 'reduce_max', {'in': 0, 'out': 0}), - ('placeholder', 'first_sub', {'in': 0, 'out': 0}), - ('reduce_max', 'first_sub', {'in': 1}), - ('reduce_max_axis', 'reduce_max', {'in': 1}), - ('first_sub', 'exp', {'in': 0, 'out': 0}), - ('first_sub', 'second_sub', {'in': 0, 'out': 0}), - ('exp', 'reduce_sum', {'in': 0}), - ('reduce_sum_axis', 'reduce_sum', {'in': 1}), - ('reduce_sum', 'log'), - ('log', 'second_sub', {'in': 1}), - ('second_sub', 'output'), -] - - -@generator -class LogSoftmaxReplacerTest(unittest.TestCase): - @generate(*[(-1, 'NCHW'), (-1, 'NHWC'), (0, 'NHWC'), - (0, 'NCHW'), (2, 'NCHW'), (2, 'NHWC'), - (-2, 'NHWC'), (-2, 'NCHW')]) - def test_logsoftmax_replacer(self, axis, layout): - graph = build_graph(nodes_attrs=graph_node_attributes, edges=graph_edges) - graph_ref = build_graph(nodes_attrs=graph_ref_node_attributes, - edges=graph_ref_edges, - update_attributes={ - 'reduce_max_axis': {'value': int64_array([axis])}, - 'reduce_sum_axis': {'value': int64_array([axis])}, - }) - graph.graph['layout'] = layout - graph.stage = 'front' - LogSoftmaxFrontReplacer().find_and_replace_pattern(graph) - (flag, resp) = compare_graphs(graph, graph_ref, 'output') - self.assertTrue(flag, resp) - diff --git a/model-optimizer/extensions/front/kaldi/logsoftmax_component_ext.py b/model-optimizer/extensions/front/kaldi/logsoftmax_component_ext.py index 8d4ddc6ff43..3f60ae944da 100644 --- a/model-optimizer/extensions/front/kaldi/logsoftmax_component_ext.py +++ b/model-optimizer/extensions/front/kaldi/logsoftmax_component_ext.py @@ -14,7 +14,7 @@ limitations under the License. 
""" -from mo.ops.softmax import LogSoftmax +from mo.ops.log_softmax import LogSoftmax from mo.front.extractor import FrontExtractorOp diff --git a/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py b/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py index bffd69d9a79..02c2d1f8b48 100644 --- a/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py +++ b/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py @@ -36,10 +36,6 @@ class FlattenONNXToReshape(FrontReplacementSubgraph): """ enabled = True - def run_before(self): - from extensions.front.LogSoftmax import LogSoftmaxFrontReplacer - return [LogSoftmaxFrontReplacer] - def pattern(self): return dict(nodes=[('flatten', dict(op='FlattenONNX'))], edges=[]) diff --git a/model-optimizer/extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py b/model-optimizer/extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py index 24387861eec..8a630460b20 100644 --- a/model-optimizer/extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py +++ b/model-optimizer/extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py @@ -18,7 +18,7 @@ from mo.graph.graph import Graph, Node, rename_nodes from mo.ops.flatten import FlattenONNX from mo.ops.reshape import Reshape from mo.ops.shape import Shape -from mo.ops.softmax import LogSoftmax +from mo.ops.log_softmax import LogSoftmax class LogSoftmaxONNXFrontReplacer(FrontReplacementOp): diff --git a/model-optimizer/extensions/front/onnx/softmax_ext.py b/model-optimizer/extensions/front/onnx/softmax_ext.py index 0a3c5245827..59d92d233a6 100644 --- a/model-optimizer/extensions/front/onnx/softmax_ext.py +++ b/model-optimizer/extensions/front/onnx/softmax_ext.py @@ -16,7 +16,8 @@ from mo.front.extractor import FrontExtractorOp from mo.front.onnx.extractors.utils import onnx_attr -from mo.ops.softmax import LogSoftmaxONNX, SoftmaxONNX +from mo.ops.softmax import SoftmaxONNX +from mo.ops.log_softmax import LogSoftmaxONNX class SoftmaxExtractor(FrontExtractorOp): diff --git a/model-optimizer/extensions/front/tf/log_softmax_ext.py b/model-optimizer/extensions/front/tf/log_softmax_ext.py new file mode 100644 index 00000000000..64c6e839b13 --- /dev/null +++ b/model-optimizer/extensions/front/tf/log_softmax_ext.py @@ -0,0 +1,32 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from mo.front.extractor import FrontExtractorOp +from mo.ops.log_softmax import LogSoftmax + + +class LogSoftmaxExtractor(FrontExtractorOp): + op = 'LogSoftmax' + enabled = True + + @classmethod + def extract(cls, node): + # the default value for the TF LogSoftmax is -1 + axis = -1 + if 'axis' in node.pb.attr: + axis = node.pb.attr['axis'].i + LogSoftmax.update_node_stat(node, {'axis': axis}) + return cls.enabled diff --git a/model-optimizer/extensions/front/tf/softmax_ext.py b/model-optimizer/extensions/front/tf/softmax_ext.py index 94c2b0ff4af..fc4461abce7 100644 --- a/model-optimizer/extensions/front/tf/softmax_ext.py +++ b/model-optimizer/extensions/front/tf/softmax_ext.py @@ -15,7 +15,7 @@ """ from mo.front.extractor import FrontExtractorOp -from mo.ops.softmax import LogSoftmax, Softmax +from mo.ops.softmax import Softmax class SoftmaxExtractor(FrontExtractorOp): @@ -30,17 +30,3 @@ class SoftmaxExtractor(FrontExtractorOp): axis = node.pb.attr['axis'].i Softmax.update_node_stat(node, {'axis': axis}) return cls.enabled - - -class LogSoftmaxExtractor(FrontExtractorOp): - op = 'LogSoftmax' - enabled = True - - @classmethod - def extract(cls, node): - # the default value for the TF LogSoftmax is -1 - axis = -1 - if 'axis' in node.pb.attr: - axis = node.pb.attr['axis'].i - LogSoftmax.update_node_stat(node, {'axis': axis}) - return cls.enabled diff --git a/model-optimizer/mo/ops/log_softmax.py b/model-optimizer/mo/ops/log_softmax.py new file mode 100644 index 00000000000..fe6d6e9055d --- /dev/null +++ b/model-optimizer/mo/ops/log_softmax.py @@ -0,0 +1,67 @@ +""" + Copyright (C) 2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + + +from mo.front.common.partial_infer.elemental import copy_shape_infer +from mo.graph.graph import Graph, Node +from mo.ops.op import Op, PermuteAttrs + + +class LogSoftmax(Op): + op = 'LogSoftmax' + enabled = False + + def __init__(self, graph: Graph, attrs: dict): + super().__init__(graph, { + 'type': self.op, + 'op': self.op, + 'version': 'opset5', + 'infer': self.infer, + 'axis': 1, + 'in_ports_count': 1, + 'out_ports_count': 1, + }, attrs) + + def supported_attrs(self): + return ['axis'] + + @staticmethod + def infer(node: Node): + assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 1,\ + 'LogSoftmax node with id {} have more than one port connected'.format(node.id) + if node.axis < 0: + node.axis = len(node.in_port(0).data.get_shape()) + node.axis + assert 0 <= node.axis < len(node.in_port(0).data.get_shape()),\ + 'LogSoftmax node with id {} has wrong axis attribute'.format(node.id) + copy_shape_infer(node) + PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')]) + + +class LogSoftmaxONNX(Op): + op = 'LogSoftmaxONNX' + enabled = False + + def __init__(self, graph: Graph, attrs: dict): + super().__init__(graph, { + 'infer': None, + 'kind': 'op', + 'axis': 1, + 'type': None, # the operation will be replaced with a + # Reshape(LogSoftmax(FlattenONNX(x, axis), 1), x.shape) sub-graph + 'op': self.op, + 'in_ports_count': 1, + 'out_ports_count': 1, + }, attrs) diff --git a/model-optimizer/mo/ops/softmax.py b/model-optimizer/mo/ops/softmax.py index 8a6a2463db5..333a38061d9 100644 --- a/model-optimizer/mo/ops/softmax.py +++ b/model-optimizer/mo/ops/softmax.py @@ -59,35 +59,3 @@ class SoftmaxONNX(Op): 'in_ports_count': 1, 'out_ports_count': 1, }, attrs) - - -class LogSoftmax(Op): - op = 'LogSoftmax' - enabled = False - - def __init__(self, graph: Graph, attrs: dict): - super().__init__(graph, { - 'infer': None, - 'kind': 'op', - 'axis': 1, - 'type': None, # the operation will be replaced with a x - Log(ReduceSum(Exp(x), axis)) sub-graph - 'op': __class__.op, - 'in_ports_count': 1, - 'out_ports_count': 1, - }, attrs) - -class LogSoftmaxONNX(Op): - op = 'LogSoftmaxONNX' - enabled = False - - def __init__(self, graph: Graph, attrs: dict): - super().__init__(graph, { - 'infer': None, - 'kind': 'op', - 'axis': 1, - 'type': None, # the operation will be replaced with a - # Reshape(LogSoftmax(FlattenONNX(x, axis), 1), x.shape) sub-graph - 'op': __class__.op, - 'in_ports_count': 1, - 'out_ports_count': 1, - }, attrs) diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/log_softmax.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/log_softmax.hpp new file mode 100644 index 00000000000..6e1caba0c33 --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/log_softmax.hpp @@ -0,0 +1,62 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +#pragma once + +#include +#include "ngraph/coordinate_transform.hpp" +#include "ngraph/runtime/reference/max.hpp" +#include "ngraph/runtime/reference/sum.hpp" +#include "ngraph/shape_util.hpp" + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void log_softmax(const T* arg, T* out, const Shape& shape, const AxisSet& axes) + { + auto temp_shape = reduce(shape, axes, true); + auto temp_elements = shape_size(temp_shape); + auto temp_max = std::vector(temp_elements, 0); + auto temp_sum = std::vector(temp_elements, 0); + + max(arg, temp_max.data(), shape, axes, true); + + CoordinateTransform transform(shape); + CoordinateTransform temp_transform(temp_shape); + for (const Coordinate& coord : transform) + { + Coordinate temp_coord = reduce(coord, axes, true); + out[transform.index(coord)] = std::exp( + arg[transform.index(coord)] - temp_max[temp_transform.index(temp_coord)]); + } + + sum(out, temp_sum.data(), shape, axes, true); + + for (const Coordinate& coord : transform) + { + Coordinate temp_coord = reduce(coord, axes, true); + out[transform.index(coord)] = + (arg[transform.index(coord)] - temp_max[temp_transform.index(temp_coord)]) - + std::log(temp_sum[temp_transform.index(temp_coord)]); + } + } + } // namespace reference + } // namespace runtime +} // namespace ngraph diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index e39adcb2b9f..706690a2397 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -294,6 +294,7 @@ set(MULTI_TEST_SRC backend/group_convolution.in.cpp backend/interpolate.in.cpp backend/log.in.cpp + backend/log_softmax.in.cpp backend/logical_or.in.cpp backend/logical_xor.in.cpp backend/lrn.in.cpp diff --git a/ngraph/test/backend/log_softmax.in.cpp b/ngraph/test/backend/log_softmax.in.cpp new file mode 100644 index 00000000000..1304e815632 --- /dev/null +++ b/ngraph/test/backend/log_softmax.in.cpp @@ -0,0 +1,368 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +// clang-format off +#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS +#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS +#endif + +#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS +#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS +#endif +// clang-format on + +#include "gtest/gtest.h" +#include "runtime/backend.hpp" +#include "ngraph/runtime/tensor.hpp" +#include "ngraph/ngraph.hpp" +#include "util/all_close.hpp" +#include "util/all_close_f.hpp" +#include "util/known_element_types.hpp" +#include "util/ndarray.hpp" +#include "util/test_control.hpp" +#include "util/test_tools.hpp" + +NGRAPH_SUPPRESS_DEPRECATED_START + +using namespace std; +using namespace ngraph; + +static string s_manifest = "${MANIFEST}"; + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_1d_single_value) +{ + Shape shape{1}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{1}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{0}; + + auto f = make_shared(make_shared(A, 0), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis0) +{ + Shape shape{2, 4}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-10000., -10000., -10000., -10000., 0., 0., 0., 0.}; + + auto f = make_shared(make_shared(A, 0), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis1) +{ + Shape shape{2, 4}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-3.4401896, + -2.4401896, + -1.4401897, + -0.4401897, + -3.4401896, + -2.4401896, + -1.4401897, + -0.4401897}; + + auto f = make_shared(make_shared(A, 1), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis_neg1) +{ + Shape shape{2, 4}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-3.4401896, + -2.4401896, + -1.4401897, + -0.4401897, + -3.4401896, + -2.4401896, + -1.4401897, + -0.4401897}; + + auto f = make_shared(make_shared(A, -1), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + 
EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_2d_axis_neg2) +{ + Shape shape{2, 4}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-10000., -10000., -10000., -10000., 0., 0., 0., 0.}; + + auto f = make_shared(make_shared(A, -2), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_0) +{ + Shape shape{3, 2, 3}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-12.0024818, + -12.0024818, + -12.0024818, + -12.0024818, + -12.0024818, + -12.0024818, + -6.00248181, + -6.00248181, + -6.00248181, + -6.00248181, + -6.00248181, + -6.00248181, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03}; + + auto f = make_shared(make_shared(A, 0), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_1) +{ + Shape shape{3, 2, 3}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-3.04858735, + -3.04858735, + -3.04858735, + -0.04858735, + -0.04858735, + -0.04858735, + -3.04858735, + -3.04858735, + -3.04858735, + -0.04858735, + -0.04858735, + -0.04858735, + -3.04858735, + -3.04858735, + -3.04858735, + -0.04858735, + -0.04858735, + -0.04858735}; + + auto f = make_shared(make_shared(A, 1), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_2) +{ + Shape shape{3, 2, 3}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596}; + + auto f = make_shared(make_shared(A, 2), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + 
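+// The expected values in these tests follow from log_softmax(x)_i = x_i - log(sum_j exp(x_j))
+// along the reduced axis; the result is invariant to a constant shift of the slice, so every
+// slice of the form {c, c+1, c+2} in the axis_2 case above yields approximately
+// {-2.40760596, -1.40760596, -0.40760596}.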
+NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_neg1) +{ + Shape shape{3, 2, 3}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596, + -2.40760596, + -1.40760596, + -0.40760596}; + + auto f = make_shared(make_shared(A, -1), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_neg2) +{ + Shape shape{3, 2, 3}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-3.04858735, + -3.04858735, + -3.04858735, + -0.04858735, + -0.04858735, + -0.04858735, + -3.04858735, + -3.04858735, + -3.04858735, + -0.04858735, + -0.04858735, + -0.04858735, + -3.04858735, + -3.04858735, + -3.04858735, + -0.04858735, + -0.04858735, + -0.04858735}; + + auto f = make_shared(make_shared(A, -2), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} + +NGRAPH_TEST(${BACKEND_NAME}, log_softmax_3d_axis_neg3) +{ + Shape shape{3, 2, 3}; + auto A = make_shared(element::f32, shape); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto a = backend->create_tensor(element::f32, shape); + copy_data(a, vector{-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8}); + auto result = backend->create_tensor(element::f32, shape); + + std::vector expected_result{-12.0024818, + -12.0024818, + -12.0024818, + -12.0024818, + -12.0024818, + -12.0024818, + -6.00248181, + -6.00248181, + -6.00248181, + -6.00248181, + -6.00248181, + -6.00248181, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03, + -2.48181414e-03}; + + auto f = make_shared(make_shared(A, -3), ParameterVector{A}); + auto handle = backend->compile(f); + handle->call_with_validate({result}, {a}); + EXPECT_TRUE(test::all_close(expected_result, read_vector(result))); +} diff --git a/ngraph/test/runtime/ie/ie_executable.cpp b/ngraph/test/runtime/ie/ie_executable.cpp index d3f959cd7a3..eba5a300e34 100644 --- a/ngraph/test/runtime/ie/ie_executable.cpp +++ b/ngraph/test/runtime/ie/ie_executable.cpp @@ -85,6 +85,10 @@ namespace ie_ops.insert(opset2.begin(), opset2.end()); auto& opset3 = get_opset3().get_type_info_set(); ie_ops.insert(opset3.begin(), opset3.end()); + auto& opset4 = get_opset4().get_type_info_set(); + ie_ops.insert(opset4.begin(), opset4.end()); + auto& opset5 = get_opset5().get_type_info_set(); + ie_ops.insert(opset5.begin(), opset5.end()); return ie_ops; } } diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index abee8acb08f..2b8e949435f 100644 --- 
a/ngraph/test/runtime/ie/unit_test.manifest
+++ b/ngraph/test/runtime/ie/unit_test.manifest
@@ -1130,6 +1130,13 @@ IE_CPU.onnx_resize11_scales_nearest_asymmetric_floor_dynamic_sizes
 # Input data precision not supported. Expected float.
 ctc_greedy_decoder_f16
 
+# Wrong output when axis is 0
+IE_CPU.log_softmax_1d_single_value
+IE_CPU.log_softmax_2d_axis0
+IE_CPU.log_softmax_2d_axis_neg2
+IE_CPU.log_softmax_3d_axis_0
+IE_CPU.log_softmax_3d_axis_neg3
+
 #-------------------------------------------------------------------------------
 #
 # Inference Engine GPU plugin excludes
diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp
index cc54b84f3ef..bd5db5ed66f 100644
--- a/ngraph/test/runtime/interpreter/int_executable.hpp
+++ b/ngraph/test/runtime/interpreter/int_executable.hpp
@@ -62,6 +62,7 @@
 #include "ngraph/runtime/reference/gather_tree.hpp"
 #include "ngraph/runtime/reference/gru_cell.hpp"
 #include "ngraph/runtime/reference/log.hpp"
+#include "ngraph/runtime/reference/log_softmax.hpp"
 #include "ngraph/runtime/reference/lrn.hpp"
 #include "ngraph/runtime/reference/lstm_cell.hpp"
 #include "ngraph/runtime/reference/matmul.hpp"
@@ -874,6 +875,20 @@ protected:
                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
             break;
         }
+        case OP_TYPEID::LogSoftmax_v5:
+        {
+            const op::v5::LogSoftmax* log_softmax = static_cast<const op::v5::LogSoftmax*>(&node);
+            int64_t i_axis = log_softmax->get_axis();
+            if (i_axis < 0)
+            {
+                i_axis += args[0]->get_partial_shape().rank().get_length();
+            }
+            reference::log_softmax<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<T>(),
+                                      node.get_output_shape(0),
+                                      AxisSet{(size_t)i_axis});
+            break;
+        }
         case OP_TYPEID::LRN:
         {
             const op::LRN* lrn = static_cast<const op::LRN*>(&node);
diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp
index 4cfe6693f17..61fa35ddec2 100644
--- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp
+++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp
@@ -57,4 +57,5 @@ NGRAPH_OP(GatherND, op::v5)
 NGRAPH_OP(LSTMSequence, op::v5)
 NGRAPH_OP(GRUSequence, op::v5)
 NGRAPH_OP(RNNSequence, op::v5)
+NGRAPH_OP(LogSoftmax, op::v5)
 #undef ID_SUFFIX
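
For reference, the interpreter path above (negative-axis normalization in int_executable.hpp followed by the numerically stable reference kernel in log_softmax.hpp) corresponds to the following minimal NumPy sketch; the helper name and the example call are illustrative only and are not part of the patch:

    import numpy as np

    def log_softmax(x, axis):
        # normalize a negative axis the same way int_executable.hpp does
        if axis < 0:
            axis += x.ndim
        # subtract the per-axis maximum before exponentiation (keep_dims semantics)
        shifted = x - np.max(x, axis=axis, keepdims=True)
        return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

    # Example: reproduces the expected_result of the log_softmax_2d_axis1 backend test.
    print(log_softmax(np.array([[0., 1., 2., 3.], [10000., 10001., 10002., 10003.]]), axis=1))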