Avoid unnecessary Reshape in ONNX Softmax impl (#2686)
This commit is contained in:
parent
0dde02e44f
commit
5eee1ea925
@ -33,16 +33,11 @@ namespace ngraph
|
|||||||
const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
|
const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
|
||||||
|
|
||||||
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||||
const auto max = std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1);
|
const auto max =
|
||||||
|
std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
|
||||||
// equivalent to numpy's max.reshape((-1,1))
|
|
||||||
const auto reshape_pattern =
|
|
||||||
default_opset::Constant::create(element::i64, Shape{2}, {0, 1});
|
|
||||||
const auto reshaped_max =
|
|
||||||
std::make_shared<default_opset::Reshape>(max, reshape_pattern, true);
|
|
||||||
|
|
||||||
const auto data_minus_max =
|
const auto data_minus_max =
|
||||||
std::make_shared<default_opset::Subtract>(coerced_data, reshaped_max);
|
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
||||||
|
|
||||||
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
|
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
|
||||||
if (data.get_partial_shape().is_static())
|
if (data.get_partial_shape().is_static())
|
||||||
|
Loading…
Reference in New Issue
Block a user