diff --git a/ngraph/frontend/onnx_import/src/op/softmax.cpp b/ngraph/frontend/onnx_import/src/op/softmax.cpp index b0c4fe4081c..87c7e5192f7 100644 --- a/ngraph/frontend/onnx_import/src/op/softmax.cpp +++ b/ngraph/frontend/onnx_import/src/op/softmax.cpp @@ -33,16 +33,11 @@ namespace ngraph const auto coerced_data = ngraph::builder::opset1::flatten(data, axis); const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1}); - const auto max = std::make_shared(coerced_data, axis_1); - - // equivalent to numpy's max.reshape((-1,1)) - const auto reshape_pattern = - default_opset::Constant::create(element::i64, Shape{2}, {0, 1}); - const auto reshaped_max = - std::make_shared(max, reshape_pattern, true); + const auto max = + std::make_shared(coerced_data, axis_1, true); const auto data_minus_max = - std::make_shared(coerced_data, reshaped_max); + std::make_shared(coerced_data, max); const auto result = std::make_shared(data_minus_max, 1); if (data.get_partial_shape().is_static())