Avoid unnecessary Reshape in ONNX Softmax impl (#2686)
This commit is contained in:
parent
0dde02e44f
commit
5eee1ea925
@ -33,16 +33,11 @@ namespace ngraph
|
||||
const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
|
||||
|
||||
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||
const auto max = std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1);
|
||||
|
||||
// equivalent to numpy's max.reshape((-1,1))
|
||||
const auto reshape_pattern =
|
||||
default_opset::Constant::create(element::i64, Shape{2}, {0, 1});
|
||||
const auto reshaped_max =
|
||||
std::make_shared<default_opset::Reshape>(max, reshape_pattern, true);
|
||||
const auto max =
|
||||
std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
|
||||
|
||||
const auto data_minus_max =
|
||||
std::make_shared<default_opset::Subtract>(coerced_data, reshaped_max);
|
||||
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
||||
|
||||
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
|
||||
if (data.get_partial_shape().is_static())
|
||||
|
Loading…
Reference in New Issue
Block a user