diff --git a/ngraph/frontend/onnx_import/src/op/log_softmax.cpp b/ngraph/frontend/onnx_import/src/op/log_softmax.cpp index 07c71a5663b..31745da1a49 100644 --- a/ngraph/frontend/onnx_import/src/op/log_softmax.cpp +++ b/ngraph/frontend/onnx_import/src/op/log_softmax.cpp @@ -40,15 +40,8 @@ namespace ngraph std::make_shared(coerced_data, max); const auto result = std::make_shared(data_minus_max, 1); - if (data.get_partial_shape().is_static()) - { - return ngraph::builder::opset1::reshape(result, data.get_shape()); - } - else - { - const auto data_shape = std::make_shared(data); - return std::make_shared(result, data_shape, false); - } + const auto data_shape = std::make_shared(data); + return std::make_shared(result, data_shape, false); } OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS) diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 799355a165c..1a6bcfbb435 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -444,16 +444,13 @@ if (MSVC) target_compile_options(unit-test PRIVATE "/bigobj") endif() -if (TARGET inference_engine) - target_link_libraries(unit-test PRIVATE inference_engine) -endif() +target_link_libraries(unit-test PRIVATE inference_engine) +target_link_libraries(unit-test PRIVATE ie_backend) if (NGRAPH_ONNX_IMPORT_ENABLE) target_link_libraries(unit-test PRIVATE onnx_importer) endif() -target_link_libraries(unit-test PRIVATE ie_backend) - if (NGRAPH_INTERPRETER_ENABLE) target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE) target_link_libraries(unit-test PRIVATE interpreter_backend) diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 4189685bb58..619e24e5743 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -50,6 +50,7 @@ #include "util/engine/test_engines.hpp" #include "util/test_tools.hpp" #include "util/type_prop.hpp" +#include NGRAPH_SUPPRESS_DEPRECATED_START @@ -3546,6 +3547,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D) test_case.run_with_tolerance_as_fp(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D_reshape) +{ + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt")); + InferenceEngine::CNNNetwork net(function); + InferenceEngine::ICNNNetwork::InputShapes shapes = {}; + InferenceEngine::SizeVector shape = {1, 1, 4000}; + shapes[net.getInputsInfo().begin()->first] = shape; + EXPECT_NO_THROW(net.reshape(shapes)); + ASSERT_EQ(shape, net.getOutputsInfo().begin()->second->getDims()); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid) { auto function = onnx_import::import_onnx_model(