[ONNX] don't hardcode shape in Softmax operator (#3676)

* [ONNX] don't hardcode shape in Softmax operator

* use named constant for special zero param in reshape
This commit is contained in:
Mateusz Tabaka 2020-12-25 05:41:46 +01:00 committed by GitHub
parent cc019e0a11
commit 1aee9f9ffe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 10 deletions

View File

@ -40,15 +40,9 @@ namespace ngraph
std::make_shared<default_opset::Subtract>(coerced_data, max);
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
if (data.get_partial_shape().is_static())
{
return ngraph::builder::opset1::reshape(result, data.get_shape());
}
else
{
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
}
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
const bool special_zero = false;
return std::make_shared<default_opset::Reshape>(result, data_shape, special_zero);
}
}

View File

@ -806,6 +806,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_0)
test_case.add_input<float>(SOFTMAX_INPUT);
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.09683057, 0.00369363, 0.01394559, 0.00329012, 0.00234823, 0.00757665, 0.02449322,
0.02019284, 0.04985249, 0.00592694, 0.00279593, 0.04505148, 0.00641108, 0.00458466,
0.00348007, 0.00172928, 0.00330577, 0.01093237, 0.01554086, 0.10351497,
@ -826,10 +827,11 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_1)
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_1.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
test_case.add_input<float>(SOFTMAX_INPUT);
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.22757064, 0.00868076, 0.03277484, 0.00773243, 0.0055188, 0.0178066, 0.05756383,
0.04745709, 0.11716303, 0.01392945, 0.00657097, 0.10587974, 0.01506727, 0.01077484,
0.00817884, 0.00406413, 0.00776921, 0.0256932, 0.03652405, 0.24328028,

View File

@ -1129,6 +1129,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_model_softmax_axis_2)
test_case.add_input<float>(input);
test_case.add_expected_output<float>(
Shape{3, 4, 5},
{0.80619486, 0.03075257, 0.1161086, 0.027393, 0.01955098, 0.07012682, 0.22670066,
0.18689779, 0.4614171, 0.05485763, 0.04486172, 0.72286838, 0.10286818, 0.07356265,
0.05583908, 0.01280724, 0.02448298, 0.08096658, 0.11509768, 0.76664552,