[ONNX] don't hardcode shape in Softmax operator (#3676)
* [ONNX] don't hardcode shape in Softmax operator * use named constant for special zero param in reshape
This commit is contained in:
parent
cc019e0a11
commit
1aee9f9ffe
@ -40,15 +40,9 @@ namespace ngraph
|
||||
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
||||
|
||||
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
|
||||
if (data.get_partial_shape().is_static())
|
||||
{
|
||||
return ngraph::builder::opset1::reshape(result, data.get_shape());
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
|
||||
}
|
||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
const bool special_zero = false;
|
||||
return std::make_shared<default_opset::Reshape>(result, data_shape, special_zero);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -806,6 +806,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_0)
|
||||
test_case.add_input<float>(SOFTMAX_INPUT);
|
||||
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{3, 4, 5},
|
||||
{0.09683057, 0.00369363, 0.01394559, 0.00329012, 0.00234823, 0.00757665, 0.02449322,
|
||||
0.02019284, 0.04985249, 0.00592694, 0.00279593, 0.04505148, 0.00641108, 0.00458466,
|
||||
0.00348007, 0.00172928, 0.00330577, 0.01093237, 0.01554086, 0.10351497,
|
||||
@ -826,10 +827,11 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_axis_1)
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_axis_1.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||
test_case.add_input<float>(SOFTMAX_INPUT);
|
||||
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{3, 4, 5},
|
||||
{0.22757064, 0.00868076, 0.03277484, 0.00773243, 0.0055188, 0.0178066, 0.05756383,
|
||||
0.04745709, 0.11716303, 0.01392945, 0.00657097, 0.10587974, 0.01506727, 0.01077484,
|
||||
0.00817884, 0.00406413, 0.00776921, 0.0256932, 0.03652405, 0.24328028,
|
||||
|
@ -1129,6 +1129,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_model_softmax_axis_2)
|
||||
test_case.add_input<float>(input);
|
||||
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{3, 4, 5},
|
||||
{0.80619486, 0.03075257, 0.1161086, 0.027393, 0.01955098, 0.07012682, 0.22670066,
|
||||
0.18689779, 0.4614171, 0.05485763, 0.04486172, 0.72286838, 0.10286818, 0.07356265,
|
||||
0.05583908, 0.01280724, 0.02448298, 0.08096658, 0.11509768, 0.76664552,
|
||||
|
Loading…
Reference in New Issue
Block a user