Extend ONNX FE for operation Softmax-8 (#9189)

This commit is contained in:
Dawid Kożykowski 2021-12-15 21:40:43 +01:00 committed by GitHub
parent b643294300
commit 0b9158c2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 296 additions and 59 deletions

View File

@ -0,0 +1,56 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "Softmax"
attribute {
name: "axis"
i: 1
type: INT
}
}
name: "test_softmax_axis_1"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 11
}

View File

@ -0,0 +1,56 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "Softmax"
attribute {
name: "axis"
i: -1
type: INT
}
}
name: "test_softmax_axis_0"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 11
}

View File

@ -0,0 +1,56 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "Softmax"
attribute {
name: "axis"
i: -1
type: INT
}
}
name: "test_softmax_axis_0"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -690,19 +690,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_1D) {
}
namespace {
// common input for all Softmax 3D test cases (Shape = {3,4,5})
// clang-format off
const std::vector<float> SOFTMAX_INPUT = {
2.75793882, -0.50841322, 0.82013929, -0.62409912, -0.96136118, 0.21004745, 1.38337255,
1.19030397, 2.0940445, -0.03551657, -0.78686039, 1.992782, 0.04300319, -0.29230777,
-0.56797112, -1.26732165, -0.61935399, 0.57670432, 0.92844898, 2.82469233,
2.75793882, -0.50841322, 0.82013929, -0.62409912, -0.96136118,
0.21004745, 1.38337255, 1.19030397, 2.0940445, -0.03551657,
-0.78686039, 1.992782, 0.04300319, -0.29230777, -0.56797112,
-1.26732165, -0.61935399, 0.57670432, 0.92844898, 2.82469233,
0.98721677, -0.05100663, -1.21178917, -0.17530157, 1.40051805, -0.13259761, -1.14313018,
0.2673723, -0.87996154, 1.29053106, 1.55, 0.8396538, 1.20729817, 0.23727845,
-0.89113606, -1.70909842, 0.26460363, -0.70566808, 2.383518, 1.07024615,
0.98721677, -0.05100663, -1.21178917, -0.17530157, 1.40051805,
-0.13259761, -1.14313018, 0.2673723, -0.87996154, 1.29053106,
1.55, 0.8396538, 1.20729817, 0.23727845, -0.89113606,
-1.70909842, 0.26460363, -0.70566808, 2.383518, 1.07024615,
-1.21722605, 0.82919357, 0.55765697, 0.12657686, 0.63432172, 0.75425957, -2.43721014,
-1.24478184, 2.65316853, 1.19509542, -0.95523998, 0.5149006, -0.01151649, 0.68327026,
-0.4589638, -0.46554745, 0.21055324, 0.39266729, 2.05098086, 1.83207919};
-1.21722605, 0.82919357, 0.55765697, 0.12657686, 0.63432172,
0.75425957, -2.43721014, -1.24478184, 2.65316853, 1.19509542,
-0.95523998, 0.5149006, -0.01151649, 0.68327026, -0.4589638,
-0.46554745, 0.21055324, 0.39266729, 2.05098086, 1.83207919};
} // namespace
// clang-format on
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_0) {
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_0.onnx"));
@ -710,19 +715,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_0) {
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>(SOFTMAX_INPUT);
// clang-format off
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.09683057, 0.00369363, 0.01394559, 0.00329012, 0.00234823, 0.00757665, 0.02449322,
0.02019284, 0.04985249, 0.00592694, 0.00279593, 0.04505148, 0.00641108, 0.00458466,
0.00348007, 0.00172928, 0.00330577, 0.01093237, 0.01554086, 0.10351497,
{0.09683057, 0.00369363, 0.01394559, 0.00329012, 0.00234823,
0.00757665, 0.02449322, 0.02019284, 0.04985249, 0.00592694,
0.00279593, 0.04505148, 0.00641108, 0.00458466, 0.00348007,
0.00172928, 0.00330577, 0.01093237, 0.01554086, 0.10351497,
0.01648154, 0.00583583, 0.00182802, 0.00515374, 0.02491679, 0.00537859, 0.00195794,
0.00802367, 0.00254737, 0.0223216, 0.02893419, 0.0142204, 0.02053893, 0.00778581,
0.00251907, 0.00111174, 0.00800149, 0.0030324, 0.06658917, 0.0179084,
0.01648154, 0.00583583, 0.00182802, 0.00515374, 0.02491679,
0.00537859, 0.00195794, 0.00802367, 0.00254737, 0.0223216,
0.02893419, 0.0142204, 0.02053893, 0.00778581, 0.00251907,
0.00111174, 0.00800149, 0.0030324, 0.06658917, 0.0179084,
0.00181811, 0.01407243, 0.01072611, 0.0069699, 0.01158077, 0.01305647, 0.00053677,
0.0017687, 0.08719896, 0.02028982, 0.00236265, 0.01027717, 0.0060709, 0.01216173,
0.00388087, 0.00385541, 0.00758048, 0.00909469, 0.04775123, 0.03836337});
0.00181811, 0.01407243, 0.01072611, 0.0069699, 0.01158077,
0.01305647, 0.00053677, 0.0017687, 0.08719896, 0.02028982,
0.00236265, 0.01027717, 0.0060709, 0.01216173, 0.00388087,
0.00385541, 0.00758048, 0.00909469, 0.04775123, 0.03836337});
// clang-format on
test_case.run(6);
}
@ -733,35 +743,113 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_1) {
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>(SOFTMAX_INPUT);
// clang-format off
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.22757064, 0.00868076, 0.03277484, 0.00773243, 0.0055188, 0.0178066, 0.05756383,
0.04745709, 0.11716303, 0.01392945, 0.00657097, 0.10587974, 0.01506727, 0.01077484,
0.00817884, 0.00406413, 0.00776921, 0.0256932, 0.03652405, 0.24328028,
{0.22757064, 0.00868076, 0.03277484, 0.00773243, 0.0055188,
0.0178066, 0.05756383, 0.04745709, 0.11716303, 0.01392945,
0.00657097, 0.10587974, 0.01506727, 0.01077484, 0.00817884,
0.00406413, 0.00776921, 0.0256932, 0.03652405, 0.24328028,
0.06217413, 0.02201481, 0.00689594, 0.01944171, 0.09399488, 0.02028993, 0.00738604,
0.03026811, 0.00960958, 0.08420492, 0.10914991, 0.05364435, 0.07748005, 0.02937079,
0.0095028, 0.00419387, 0.03018442, 0.01143929, 0.2511977, 0.06755678,
0.06217413, 0.02201481, 0.00689594, 0.01944171, 0.09399488,
0.02028993, 0.00738604, 0.03026811, 0.00960958, 0.08420492,
0.10914991, 0.05364435, 0.07748005, 0.02937079, 0.0095028,
0.00419387, 0.03018442, 0.01143929, 0.2511977, 0.06755678,
0.00587593, 0.04548053, 0.0346656, 0.02252594, 0.03742775, 0.04219705, 0.00173478,
0.00571623, 0.2818174, 0.06557446, 0.00763582, 0.03321466, 0.01962049, 0.03930537,
0.01254255, 0.01246025, 0.02449929, 0.02939305, 0.15432668, 0.12398617});
0.00587593, 0.04548053, 0.0346656, 0.02252594, 0.03742775,
0.04219705, 0.00173478, 0.00571623, 0.2818174, 0.06557446,
0.00763582, 0.03321466, 0.01962049, 0.03930537, 0.01254255,
0.01246025, 0.02449929, 0.02939305, 0.15432668, 0.12398617});
// clang-format on
test_case.run(4);
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_invalid_axis_1D) {
ASSERT_THROW(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_invalid_axis_1D.onnx")),
ngraph::ngraph_error)
<< "Softmax model with invalid axis was successfully imported while it should have thrown.";
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_1_opset11) {
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_1_opset11.onnx"));
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>(SOFTMAX_INPUT);
// clang-format off
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.88890495, 0.04825497, 0.27088348, 0.04490523, 0.02037154,
0.06955369, 0.31998834, 0.39223197, 0.68041159, 0.05141776,
0.02566661, 0.5885689, 0.12453075, 0.06257374, 0.03019055,
0.01587475, 0.0431878, 0.21235381, 0.21210944, 0.89802015,
0.31752626, 0.19442629, 0.0546935, 0.06279221, 0.36823282,
0.10362164, 0.06523066, 0.24006419, 0.03103672, 0.32987983,
0.55743381, 0.473766, 0.61451431, 0.09486084, 0.03722801,
0.02141829, 0.26657706, 0.090728, 0.81131024, 0.26465935,
0.08619648, 0.43343993, 0.3877785, 0.04523505, 0.15625437,
0.61900597, 0.01653285, 0.06394322, 0.56592636, 0.27376196,
0.11201305, 0.31654337, 0.21947994, 0.07893034, 0.05236297,
0.18278451, 0.23348385, 0.32879834, 0.30990825, 0.5176207});
// clang-format on
test_case.run(4);
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_invalid_axis_3D) {
ASSERT_THROW(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_invalid_axis_3D.onnx")),
ngraph::ngraph_error)
<< "Softmax model with invalid axis was successfully imported while it should have thrown.";
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_negative_1_opset11) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_negative_1_opset11.onnx"));
auto test_case = test::TestCase(function);
test_case.add_input<float>(SOFTMAX_INPUT);
// clang-format off
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.88890495, 0.04825497, 0.27088348, 0.04490523, 0.02037154,
0.06955369, 0.31998834, 0.39223197, 0.68041159, 0.05141776,
0.02566661, 0.5885689, 0.12453075, 0.06257374, 0.03019055,
0.01587475, 0.0431878, 0.21235381, 0.21210944, 0.89802015,
0.31752626, 0.19442629, 0.0546935, 0.06279221, 0.36823282,
0.10362164, 0.06523066, 0.24006419, 0.03103672, 0.32987983,
0.55743381, 0.473766, 0.61451431, 0.09486084, 0.03722801,
0.02141829, 0.26657706, 0.090728, 0.81131024, 0.26465935,
0.08619648, 0.43343993, 0.3877785, 0.04523505, 0.15625437,
0.61900597, 0.01653285, 0.06394322, 0.56592636, 0.27376196,
0.11201305, 0.31654337, 0.21947994, 0.07893034, 0.05236297,
0.18278451, 0.23348385, 0.32879834, 0.30990825, 0.5176207});
// clang-format on
test_case.run(6);
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_negative_1_opset13) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_negative_1_opset13.onnx"));
auto test_case = test::TestCase(function);
test_case.add_input<float>(SOFTMAX_INPUT);
// clang-format off
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.88890495, 0.04825497, 0.27088348, 0.04490523, 0.02037154,
0.06955369, 0.31998834, 0.39223197, 0.68041159, 0.05141776,
0.02566661, 0.5885689, 0.12453075, 0.06257374, 0.03019055,
0.01587475, 0.0431878, 0.21235381, 0.21210944, 0.89802015,
0.31752626, 0.19442629, 0.0546935, 0.06279221, 0.36823282,
0.10362164, 0.06523066, 0.24006419, 0.03103672, 0.32987983,
0.55743381, 0.473766, 0.61451431, 0.09486084, 0.03722801,
0.02141829, 0.26657706, 0.090728, 0.81131024, 0.26465935,
0.08619648, 0.43343993, 0.3877785, 0.04523505, 0.15625437,
0.61900597, 0.01653285, 0.06394322, 0.56592636, 0.27376196,
0.11201305, 0.31654337, 0.21947994, 0.07893034, 0.05236297,
0.18278451, 0.23348385, 0.32879834, 0.30990825, 0.5176207});
// clang-format on
test_case.run(6);
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_sub) {

View File

@ -37,17 +37,8 @@ OutputVector softmax(const Node& node) {
result = default_opset::Constant::create(data.get_element_type(), Shape{}, {1});
break;
}
case 1: {
// checks if the axis belongs to the allowed values set (-1 and 0 for 1D)
ngraph::normalize_axis(node.get_description(), axis, data.get_partial_shape().rank());
result = std::make_shared<default_opset::Softmax>(data, 0);
break;
}
default: {
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, data.get_partial_shape().rank());
result = onnx_softmax(data, normalized_axis);
result = onnx_softmax(data, axis);
break;
}
}
@ -69,17 +60,8 @@ OutputVector softmax(const Node& node) {
result = default_opset::Constant::create(data.get_element_type(), Shape{}, {1});
break;
}
case 1: {
// checks if the axis belongs to the allowed values set (-1 and 0 for 1D)
ngraph::normalize_axis(node.get_description(), axis, data.get_partial_shape().rank());
result = std::make_shared<default_opset::Softmax>(data, 0);
break;
}
default: {
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, data.get_partial_shape().rank());
result = std::make_shared<default_opset::Softmax>(data, normalized_axis);
result = std::make_shared<ov::op::v8::Softmax>(data, axis);
break;
}
}
@ -92,9 +74,8 @@ OutputVector softmax(const Node& node) {
const auto data = node.get_ng_inputs().at(0);
const auto axis = node.get_attribute_value<int64_t>("axis", -1);
const auto normalized_axis = ngraph::normalize_axis(node.get_description(), axis, data.get_partial_shape().rank());
return {std::make_shared<default_opset::Softmax>(data, normalized_axis)};
return {std::make_shared<ov::op::v8::Softmax>(data, axis)};
}
} // namespace set_13
} // namespace op