Avoid unnecessary Reshape in ONNX Softmax impl (#2686)

This commit is contained in:
Tomasz Dołbniak 2020-10-16 11:30:20 +02:00 committed by GitHub
parent 0dde02e44f
commit 5eee1ea925
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -33,16 +33,11 @@ namespace ngraph
const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
const auto max = std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1);
// equivalent to numpy's max.reshape((-1,1))
const auto reshape_pattern =
default_opset::Constant::create(element::i64, Shape{2}, {0, 1});
const auto reshaped_max =
std::make_shared<default_opset::Reshape>(max, reshape_pattern, true);
const auto max =
std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
const auto data_minus_max =
std::make_shared<default_opset::Subtract>(coerced_data, reshaped_max);
std::make_shared<default_opset::Subtract>(coerced_data, max);
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
if (data.get_partial_shape().is_static())