Use LogSoftmax-5 in the onnx_importer (#2602)

Author: Tomasz Dołbniak
Date:   2020-10-21 10:50:16 +02:00 (committed by GitHub)
parent 8a1653b0d1
commit 3688ff4c51
13 changed files with 303 additions and 22 deletions

onnx_import/op/log_softmax.hpp

@@ -31,6 +31,12 @@ namespace ngraph
             } // namespace set_1
 
+            namespace set_13
+            {
+                OutputVector log_softmax(const Node& node);
+
+            } // namespace set_1
+
         } // namespace op
     } // namespace onnx_import

onnx_import/op/log_softmax.cpp

@@ -17,6 +17,7 @@
 #include <memory>
 
 #include "log_softmax.hpp"
+#include "ngraph/builder/reshape.hpp"
 #include "ngraph/validation_util.hpp"
 #include "onnx_import/default_opset.hpp"
@@ -24,25 +25,82 @@ namespace ngraph
 {
     namespace onnx_import
     {
-        namespace op
+        namespace detail
         {
-            namespace set_1
+            std::shared_ptr<ngraph::Node> onnx_logsoftmax(const Output<ngraph::Node> data,
+                                                          const int64_t axis)
             {
-                OutputVector log_softmax(const Node& node)
+                const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
+
+                const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
+                const auto max =
+                    std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
+                const auto data_minus_max =
+                    std::make_shared<default_opset::Subtract>(coerced_data, max);
+
+                const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
+                if (data.get_partial_shape().is_static())
+                {
+                    return ngraph::builder::opset1::reshape(result, data.get_shape());
+                }
+                else
+                {
+                    const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
+                    return std::make_shared<default_opset::Reshape>(result, data_shape, false);
+                }
+            }
+
+            OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)
             {
                 OutputVector inputs{node.get_ng_inputs()};
                 const auto data = inputs.at(0);
                 const auto data_rank = data.get_partial_shape().rank();
-                const auto axis = node.get_attribute_value<int64_t>("axis", 1);
+
+                NGRAPH_CHECK(data_rank.is_static(),
+                             "ONNX Softmax data rank needs to be known (static)");
+
+                const auto axis = node.get_attribute_value<int64_t>("axis", DEFAULT_AXIS);
+
+                std::shared_ptr<ngraph::Node> result;
+                switch (data_rank.get_length())
+                {
+                case 0:
+                {
+                    result = default_opset::Constant::create(data.get_element_type(), Shape{}, {1});
+                    break;
+                }
+                case 1:
+                {
+                    // checks if the axis belongs to the allowed values set (-1 and 0 for 1D)
+                    ngraph::normalize_axis(node.get_description(), axis, data_rank);
+                    result = std::make_shared<default_opset::LogSoftmax>(data, 0);
+                    break;
+                }
+                default:
+                {
                     const auto normalized_axis =
                         ngraph::normalize_axis(node.get_description(), axis, data_rank);
-                    const auto softmax =
-                        std::make_shared<default_opset::Softmax>(data, normalized_axis);
-                    return {std::make_shared<default_opset::Log>(softmax)};
+
+                    result = onnx_logsoftmax(data, normalized_axis);
+                    break;
+                }
                 }
+
+                return {result};
+            }
+        }
+
+        namespace op
+        {
+            namespace set_1
+            {
+                OutputVector log_softmax(const Node& node) { return detail::log_softmax(node, 1); }
             } // namespace set_1
+
+            namespace set_13
+            {
+                OutputVector log_softmax(const Node& node) { return detail::log_softmax(node, -1); }
+            } // namespace set_1
         } // namespace op
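
Note on the new decomposition above: for inputs of rank 2 or higher, detail::onnx_logsoftmax flattens the data to 2D at the normalized axis, subtracts the row-wise maximum (log-softmax is invariant to adding a constant per row, so the shift only improves numerical range), applies LogSoftmax along axis 1, and reshapes back. A minimal NumPy sketch of the same computation (illustrative only, not code from this commit):

import numpy as np

def onnx_logsoftmax_reference(data, axis):
    # flatten(data, axis): collapse dims before `axis` into rows, the rest into columns
    coerced = data.reshape(int(np.prod(data.shape[:axis], dtype=np.int64)), -1)
    # ReduceMax (keep dims) + Subtract: shift each row so its maximum is 0
    minus_max = coerced - coerced.max(axis=1, keepdims=True)
    # LogSoftmax along axis 1 of the coerced 2D tensor
    log_sm = minus_max - np.log(np.exp(minus_max).sum(axis=1, keepdims=True))
    # Reshape back to the original (static) shape
    return log_sm.reshape(data.shape)

x = np.arange(12.0).reshape(2, 2, 3)
print(onnx_logsoftmax_reference(x, axis=1))  # log-softmax over the trailing 6 values of each row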

onnx_import/ops_bridge.cpp

@ -360,6 +360,7 @@ namespace ngraph
REGISTER_OPERATOR("Less", 1, less); REGISTER_OPERATOR("Less", 1, less);
REGISTER_OPERATOR("Log", 1, log); REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax); REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LogSoftmax", 13, log_softmax);
// REGISTER_OPERATOR("Loop", 1, loop); // Loop operator disabled for the 2021.1 release // REGISTER_OPERATOR("Loop", 1, loop); // Loop operator disabled for the 2021.1 release
REGISTER_OPERATOR("LpNormalization", 1, lp_norm); REGISTER_OPERATOR("LpNormalization", 1, lp_norm);
REGISTER_OPERATOR("LRN", 1, lrn); REGISTER_OPERATOR("LRN", 1, lrn);

python/tests/__init__.py

@@ -91,7 +91,6 @@ xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that i
 xfail_issue_35929 = xfail_test(reason="RuntimeError: Incorrect precision f64!")
 xfail_issue_35930 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError: "
                                       "Required attribute 'to' is missing.")
-xfail_issue_35932 = xfail_test(reason="Assertion error - logsoftmax results mismatch")
 xfail_issue_36437 = xfail_test(reason="RuntimeError: Cannot find blob with name: <value>")
 xfail_issue_36476 = xfail_test(reason="RuntimeError: [NOT_IMPLEMENTED] Input image format U32 is "
                                       "not supported yet...")

python/tests/test_onnx/test_backend.py

@@ -344,10 +344,7 @@ tests_expected_to_fail = [
     (xfail_issue_38091,
      "OnnxBackendNodeModelTest.test_round_cpu",
      "OnnxBackendNodeModelTest.test_mvn_cpu",
-     "OnnxBackendNodeModelTest.test_elu_example_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_default_axis_cpu"),
+     "OnnxBackendNodeModelTest.test_elu_example_cpu"),
     (xfail_issue_35929,
      "OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_broadcast_cpu",
      "OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_singleton_broadcast_cpu",

python/tests/test_onnx/test_ops_unary.py

@@ -24,8 +24,7 @@ from tests.runtime import get_runtime
 from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
 from tests import (xfail_issue_35929,
                    xfail_issue_34323,
-                   xfail_issue_35930,
-                   xfail_issue_35932)
+                   xfail_issue_35930)
 
 
 @pytest.mark.parametrize(
@@ -285,7 +284,6 @@ def test_softmax():
     ng_results = run_node(node, [data])
 
 
-@xfail_issue_35932
 def test_logsoftmax():
     def logsoftmax_2d(x):
         max_x = np.max(x, axis=1).reshape((-1, 1))
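
The hunk above is truncated inside the logsoftmax_2d helper; the standard numerically stable reference it begins looks like this (a hedged completion based on the first line, not the verbatim test code):

import numpy as np

def logsoftmax_2d(x):
    max_x = np.max(x, axis=1).reshape((-1, 1))  # per-row maximum as a column vector
    exp_x = np.exp(x - max_x)                   # exponentiate after shifting to avoid overflow
    return x - max_x - np.log(np.sum(exp_x, axis=1).reshape((-1, 1)))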

test/models/onnx/logsoftmax13_1D.prototxt

@@ -0,0 +1,39 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    output: "y"
+    op_type: "LogSoftmax"
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}
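
The .prototxt models added in this commit (this one and the three below) are ONNX protobufs in text format. A sketch of one way to regenerate an equivalent model with the standard onnx Python helpers (field values such as ir_version depend on the installed onnx version):

import onnx
from onnx import helper, TensorProto
from google.protobuf import text_format

node = helper.make_node("LogSoftmax", inputs=["data"], outputs=["y"])  # no axis -> default
graph = helper.make_graph(
    [node], "LogSoftmax test",
    [helper.make_tensor_value_info("data", TensorProto.FLOAT, [3])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [3])],
)
model = helper.make_model(
    graph,
    producer_name="nGraph ONNX Importer",
    opset_imports=[helper.make_opsetid("", 13)],
)
print(text_format.MessageToString(model))  # protobuf text, like the .prototxt above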

test/models/onnx/logsoftmax13_2D.prototxt

@@ -0,0 +1,45 @@
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "x"
+    output: "y"
+    op_type: "LogSoftmax"
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "x"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}

test/models/onnx/logsoftmax_0D.prototxt

@@ -0,0 +1,37 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    output: "y"
+    op_type: "LogSoftmax"
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+}

test/models/onnx/logsoftmax_1D.prototxt

@@ -0,0 +1,44 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    output: "y"
+    op_type: "LogSoftmax"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+}

test/onnx/onnx_import.in.cpp

@@ -2719,3 +2719,55 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
     test_case.add_expected_output<float>(shape, output);
     test_case.run();
 }
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax_0D)
+{
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_0D.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({3.141592});
+    test_case.add_expected_output<float>({0.0});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax_1D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax_1D.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({-1.0f, 0.0f, 1.0f});
+    test_case.add_expected_output<float>(Shape{3}, {-2.4076061, -1.407606, -0.407606});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_1D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_1D.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({-1.0f, 0.0f, 1.0f});
+    test_case.add_expected_output<float>(Shape{3}, {-2.4076061, -1.407606, -0.407606});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({0.0f, 1.0f, 2.0f, 3.0f, 10000, 10001, 10002, 10003});
+    test_case.add_expected_output<float>(Shape{2, 4},
+                                         {-3.4401896,
+                                          -2.4401896,
+                                          -1.4401896,
+                                          -0.44018966,
+                                          -3.4401896,
+                                          -2.4401896,
+                                          -1.4401896,
+                                          -0.44018966});
+    test_case.run_with_tolerance_as_fp();
+}

test/runtime/ie/unit_test.manifest

@@ -69,6 +69,7 @@ bool_const_op
 onnx_model_tile
 onnx_model_tile_static
 onnx_model_softmax_0D
+onnx_model_logsoftmax_0D
 builder_opset1_collapse_none
 
 # nGraph function's output number 0 was not found in the CNNNetwork built from it.

test/runtime/interpreter/unit_test.manifest

@@ -141,3 +141,7 @@ lstm_cell_bias_peepholes_clip_input_forget
 
 # unsupported element type f16
 INTERPRETER.ctc_greedy_decoder_f16
+
+# LogSoftmax's reference implementation doesn't handle scalar input properly
+onnx_model_logsoftmax_0D