diff --git a/ngraph/frontend/onnx_import/src/op/softmax.cpp b/ngraph/frontend/onnx_import/src/op/softmax.cpp index 87c7e5192f7..1e632c76fe3 100644 --- a/ngraph/frontend/onnx_import/src/op/softmax.cpp +++ b/ngraph/frontend/onnx_import/src/op/softmax.cpp @@ -40,15 +40,9 @@ namespace ngraph std::make_shared(coerced_data, max); const auto result = std::make_shared(data_minus_max, 1); - if (data.get_partial_shape().is_static()) - { - return ngraph::builder::opset1::reshape(result, data.get_shape()); - } - else - { - const auto data_shape = std::make_shared(data); - return std::make_shared(result, data_shape, false); - } + const auto data_shape = std::make_shared(data); + const bool special_zero = false; + return std::make_shared(result, data_shape, special_zero); } } diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 56827d6f7c6..a3bf395b26b 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -806,6 +806,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_0) test_case.add_input(SOFTMAX_INPUT); test_case.add_expected_output( + Shape{3, 4, 5}, {0.09683057, 0.00369363, 0.01394559, 0.00329012, 0.00234823, 0.00757665, 0.02449322, 0.02019284, 0.04985249, 0.00592694, 0.00279593, 0.04505148, 0.00641108, 0.00458466, 0.00348007, 0.00172928, 0.00330577, 0.01093237, 0.01554086, 0.10351497, @@ -826,10 +827,11 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_1) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_1.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = test::TestCase(function); test_case.add_input(SOFTMAX_INPUT); test_case.add_expected_output( + Shape{3, 4, 5}, {0.22757064, 0.00868076, 0.03277484, 0.00773243, 0.0055188, 0.0178066, 0.05756383, 0.04745709, 0.11716303, 0.01392945, 0.00657097, 0.10587974, 0.01506727, 0.01077484, 0.00817884, 0.00406413, 0.00776921, 0.0256932, 0.03652405, 0.24328028, diff --git a/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp b/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp index a922a6361d7..1facfdde274 100644 --- a/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp +++ b/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp @@ -1129,6 +1129,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_model_softmax_axis_2) test_case.add_input(input); test_case.add_expected_output( + Shape{3, 4, 5}, {0.80619486, 0.03075257, 0.1161086, 0.027393, 0.01955098, 0.07012682, 0.22670066, 0.18689779, 0.4614171, 0.05485763, 0.04486172, 0.72286838, 0.10286818, 0.07356265, 0.05583908, 0.01280724, 0.02448298, 0.08096658, 0.11509768, 0.76664552,