Use LogSoftmax-5 in the onnx_importer (#2602)
This commit is contained in:
parent
8a1653b0d1
commit
3688ff4c51
@ -31,6 +31,12 @@ namespace ngraph
|
|||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
namespace set_13
|
||||||
|
{
|
||||||
|
OutputVector log_softmax(const Node& node);
|
||||||
|
|
||||||
|
} // namespace set_1
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
} // namespace onnx_import
|
} // namespace onnx_import
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "log_softmax.hpp"
|
#include "log_softmax.hpp"
|
||||||
|
#include "ngraph/builder/reshape.hpp"
|
||||||
#include "ngraph/validation_util.hpp"
|
#include "ngraph/validation_util.hpp"
|
||||||
#include "onnx_import/default_opset.hpp"
|
#include "onnx_import/default_opset.hpp"
|
||||||
|
|
||||||
@ -24,25 +25,82 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace onnx_import
|
namespace onnx_import
|
||||||
{
|
{
|
||||||
namespace op
|
namespace detail
|
||||||
{
|
{
|
||||||
namespace set_1
|
std::shared_ptr<ngraph::Node> onnx_logsoftmax(const Output<ngraph::Node> data,
|
||||||
|
const int64_t axis)
|
||||||
{
|
{
|
||||||
OutputVector log_softmax(const Node& node)
|
const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
|
||||||
|
|
||||||
|
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||||
|
const auto max =
|
||||||
|
std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
|
||||||
|
|
||||||
|
const auto data_minus_max =
|
||||||
|
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
||||||
|
|
||||||
|
const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
|
||||||
|
if (data.get_partial_shape().is_static())
|
||||||
|
{
|
||||||
|
return ngraph::builder::opset1::reshape(result, data.get_shape());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||||
|
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)
|
||||||
{
|
{
|
||||||
OutputVector inputs{node.get_ng_inputs()};
|
OutputVector inputs{node.get_ng_inputs()};
|
||||||
const auto data = inputs.at(0);
|
const auto data = inputs.at(0);
|
||||||
const auto data_rank = data.get_partial_shape().rank();
|
const auto data_rank = data.get_partial_shape().rank();
|
||||||
|
|
||||||
const auto axis = node.get_attribute_value<int64_t>("axis", 1);
|
NGRAPH_CHECK(data_rank.is_static(),
|
||||||
|
"ONNX Softmax data rank needs to be known (static)");
|
||||||
|
|
||||||
|
const auto axis = node.get_attribute_value<int64_t>("axis", DEFAULT_AXIS);
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> result;
|
||||||
|
switch (data_rank.get_length())
|
||||||
|
{
|
||||||
|
case 0:
|
||||||
|
{
|
||||||
|
result = default_opset::Constant::create(data.get_element_type(), Shape{}, {1});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
{
|
||||||
|
// checks if the axis belongs to the allowed values set (-1 and 0 for 1D)
|
||||||
|
ngraph::normalize_axis(node.get_description(), axis, data_rank);
|
||||||
|
result = std::make_shared<default_opset::LogSoftmax>(data, 0);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
{
|
||||||
const auto normalized_axis =
|
const auto normalized_axis =
|
||||||
ngraph::normalize_axis(node.get_description(), axis, data_rank);
|
ngraph::normalize_axis(node.get_description(), axis, data_rank);
|
||||||
|
|
||||||
const auto softmax =
|
result = onnx_logsoftmax(data, normalized_axis);
|
||||||
std::make_shared<default_opset::Softmax>(data, normalized_axis);
|
break;
|
||||||
return {std::make_shared<default_opset::Log>(softmax)};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return {result};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace op
|
||||||
|
{
|
||||||
|
namespace set_1
|
||||||
|
{
|
||||||
|
OutputVector log_softmax(const Node& node) { return detail::log_softmax(node, 1); }
|
||||||
|
} // namespace set_1
|
||||||
|
|
||||||
|
namespace set_13
|
||||||
|
{
|
||||||
|
OutputVector log_softmax(const Node& node) { return detail::log_softmax(node, -1); }
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -360,6 +360,7 @@ namespace ngraph
|
|||||||
REGISTER_OPERATOR("Less", 1, less);
|
REGISTER_OPERATOR("Less", 1, less);
|
||||||
REGISTER_OPERATOR("Log", 1, log);
|
REGISTER_OPERATOR("Log", 1, log);
|
||||||
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
|
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
|
||||||
|
REGISTER_OPERATOR("LogSoftmax", 13, log_softmax);
|
||||||
// REGISTER_OPERATOR("Loop", 1, loop); // Loop operator disabled for the 2021.1 release
|
// REGISTER_OPERATOR("Loop", 1, loop); // Loop operator disabled for the 2021.1 release
|
||||||
REGISTER_OPERATOR("LpNormalization", 1, lp_norm);
|
REGISTER_OPERATOR("LpNormalization", 1, lp_norm);
|
||||||
REGISTER_OPERATOR("LRN", 1, lrn);
|
REGISTER_OPERATOR("LRN", 1, lrn);
|
||||||
|
@ -91,7 +91,6 @@ xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that i
|
|||||||
xfail_issue_35929 = xfail_test(reason="RuntimeError: Incorrect precision f64!")
|
xfail_issue_35929 = xfail_test(reason="RuntimeError: Incorrect precision f64!")
|
||||||
xfail_issue_35930 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError: "
|
xfail_issue_35930 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError: "
|
||||||
"Required attribute 'to' is missing.")
|
"Required attribute 'to' is missing.")
|
||||||
xfail_issue_35932 = xfail_test(reason="Assertion error - logsoftmax results mismatch")
|
|
||||||
xfail_issue_36437 = xfail_test(reason="RuntimeError: Cannot find blob with name: <value>")
|
xfail_issue_36437 = xfail_test(reason="RuntimeError: Cannot find blob with name: <value>")
|
||||||
xfail_issue_36476 = xfail_test(reason="RuntimeError: [NOT_IMPLEMENTED] Input image format U32 is "
|
xfail_issue_36476 = xfail_test(reason="RuntimeError: [NOT_IMPLEMENTED] Input image format U32 is "
|
||||||
"not supported yet...")
|
"not supported yet...")
|
||||||
|
@ -344,10 +344,7 @@ tests_expected_to_fail = [
|
|||||||
(xfail_issue_38091,
|
(xfail_issue_38091,
|
||||||
"OnnxBackendNodeModelTest.test_round_cpu",
|
"OnnxBackendNodeModelTest.test_round_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_mvn_cpu",
|
"OnnxBackendNodeModelTest.test_mvn_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_elu_example_cpu",
|
"OnnxBackendNodeModelTest.test_elu_example_cpu"),
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_default_axis_cpu"),
|
|
||||||
(xfail_issue_35929,
|
(xfail_issue_35929,
|
||||||
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_broadcast_cpu",
|
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_broadcast_cpu",
|
||||||
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_singleton_broadcast_cpu",
|
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_singleton_broadcast_cpu",
|
||||||
|
@ -24,8 +24,7 @@ from tests.runtime import get_runtime
|
|||||||
from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
|
from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
|
||||||
from tests import (xfail_issue_35929,
|
from tests import (xfail_issue_35929,
|
||||||
xfail_issue_34323,
|
xfail_issue_34323,
|
||||||
xfail_issue_35930,
|
xfail_issue_35930)
|
||||||
xfail_issue_35932)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -285,7 +284,6 @@ def test_softmax():
|
|||||||
ng_results = run_node(node, [data])
|
ng_results = run_node(node, [data])
|
||||||
|
|
||||||
|
|
||||||
@xfail_issue_35932
|
|
||||||
def test_logsoftmax():
|
def test_logsoftmax():
|
||||||
def logsoftmax_2d(x):
|
def logsoftmax_2d(x):
|
||||||
max_x = np.max(x, axis=1).reshape((-1, 1))
|
max_x = np.max(x, axis=1).reshape((-1, 1))
|
||||||
|
39
ngraph/test/models/onnx/logsoftmax13_1D.prototxt
Normal file
39
ngraph/test/models/onnx/logsoftmax13_1D.prototxt
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "y"
|
||||||
|
op_type: "LogSoftmax"
|
||||||
|
}
|
||||||
|
name: "LogSoftmax test"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
45
ngraph/test/models/onnx/logsoftmax13_2D.prototxt
Normal file
45
ngraph/test/models/onnx/logsoftmax13_2D.prototxt
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "x"
|
||||||
|
output: "y"
|
||||||
|
op_type: "LogSoftmax"
|
||||||
|
}
|
||||||
|
name: "LogSoftmax test"
|
||||||
|
input {
|
||||||
|
name: "x"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
37
ngraph/test/models/onnx/logsoftmax_0D.prototxt
Normal file
37
ngraph/test/models/onnx/logsoftmax_0D.prototxt
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "y"
|
||||||
|
op_type: "LogSoftmax"
|
||||||
|
}
|
||||||
|
name: "LogSoftmax test"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 1
|
||||||
|
}
|
44
ngraph/test/models/onnx/logsoftmax_1D.prototxt
Normal file
44
ngraph/test/models/onnx/logsoftmax_1D.prototxt
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "y"
|
||||||
|
op_type: "LogSoftmax"
|
||||||
|
attribute {
|
||||||
|
name: "axis"
|
||||||
|
i: 0
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "LogSoftmax test"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 1
|
||||||
|
}
|
@ -2719,3 +2719,55 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
|
|||||||
test_case.add_expected_output<float>(shape, output);
|
test_case.add_expected_output<float>(shape, output);
|
||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax_0D)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_0D.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input<float>({3.141592});
|
||||||
|
test_case.add_expected_output<float>({0.0});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax_1D)
|
||||||
|
{
|
||||||
|
const auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax_1D.prototxt"));
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
|
||||||
|
test_case.add_input<float>({-1.0f, 0.0f, 1.0f});
|
||||||
|
test_case.add_expected_output<float>(Shape{3}, {-2.4076061, -1.407606, -0.407606});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_1D)
|
||||||
|
{
|
||||||
|
const auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_1D.prototxt"));
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
|
||||||
|
test_case.add_input<float>({-1.0f, 0.0f, 1.0f});
|
||||||
|
test_case.add_expected_output<float>(Shape{3}, {-2.4076061, -1.407606, -0.407606});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D)
|
||||||
|
{
|
||||||
|
const auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt"));
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
|
||||||
|
test_case.add_input<float>({0.0f, 1.0f, 2.0f, 3.0f, 10000, 10001, 10002, 10003});
|
||||||
|
test_case.add_expected_output<float>(Shape{2, 4},
|
||||||
|
{-3.4401896,
|
||||||
|
-2.4401896,
|
||||||
|
-1.4401896,
|
||||||
|
-0.44018966,
|
||||||
|
-3.4401896,
|
||||||
|
-2.4401896,
|
||||||
|
-1.4401896,
|
||||||
|
-0.44018966});
|
||||||
|
test_case.run_with_tolerance_as_fp();
|
||||||
|
}
|
||||||
|
@ -69,6 +69,7 @@ bool_const_op
|
|||||||
onnx_model_tile
|
onnx_model_tile
|
||||||
onnx_model_tile_static
|
onnx_model_tile_static
|
||||||
onnx_model_softmax_0D
|
onnx_model_softmax_0D
|
||||||
|
onnx_model_logsoftmax_0D
|
||||||
builder_opset1_collapse_none
|
builder_opset1_collapse_none
|
||||||
|
|
||||||
# nGraph function's output number 0 was not found in the CNNNetwork built from it.
|
# nGraph function's output number 0 was not found in the CNNNetwork built from it.
|
||||||
|
@ -141,3 +141,7 @@ lstm_cell_bias_peepholes_clip_input_forget
|
|||||||
|
|
||||||
# unsupported element type f16
|
# unsupported element type f16
|
||||||
INTERPRETER.ctc_greedy_decoder_f16
|
INTERPRETER.ctc_greedy_decoder_f16
|
||||||
|
|
||||||
|
# LogSoftmax's reference implementation doesn't handle scalar input properly
|
||||||
|
onnx_model_logsoftmax_0D
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user