[ONNX] remove unnecessary ReduceMax and Subtract in Softmax op (#3717)

This commit is contained in:
Mateusz Tabaka 2021-01-07 19:02:59 +01:00 committed by GitHub
parent 7be7a8fb30
commit b56cf07f0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 10 deletions

View File

@ -31,15 +31,7 @@ namespace ngraph
const int64_t axis) const int64_t axis)
{ {
const auto coerced_data = ngraph::builder::opset1::flatten(data, axis); const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
const auto result = std::make_shared<default_opset::Softmax>(coerced_data, 1);
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
const auto max =
std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
const auto data_minus_max =
std::make_shared<default_opset::Subtract>(coerced_data, max);
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data); const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
const bool special_zero = false; const bool special_zero = false;
return std::make_shared<default_opset::Reshape>(result, data_shape, special_zero); return std::make_shared<default_opset::Reshape>(result, data_shape, special_zero);

View File

@ -1142,7 +1142,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_model_softmax_axis_2)
0.01439711, 0.70979614, 0.16515835, 0.06798343, 0.2957175, 0.17468555, 0.34994439, 0.01439711, 0.70979614, 0.16515835, 0.06798343, 0.2957175, 0.17468555, 0.34994439,
0.11166912, 0.03615172, 0.07108136, 0.08527994, 0.44775794, 0.35972905}); 0.11166912, 0.03615172, 0.07108136, 0.08527994, 0.44775794, 0.35972905});
test_case.run(4); test_case.run(3);
} }
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_range_positive_step) NGRAPH_TEST(${BACKEND_NAME}, onnx_model_range_positive_step)