Removed hardcoded shapes from LogSoftmax operation (#4475)

* Removed hardcoded shapes from LogSoftmax operation

* Added tests

* Fixed comments
This commit is contained in:
Ilya Churaev 2021-02-25 20:02:14 +03:00 committed by GitHub
parent c38e7a2986
commit 1458ba392e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 14 deletions

View File

@ -40,15 +40,8 @@ namespace ngraph
std::make_shared<default_opset::Subtract>(coerced_data, max);
const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
if (data.get_partial_shape().is_static())
{
return ngraph::builder::opset1::reshape(result, data.get_shape());
}
else
{
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
}
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
}
OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)

View File

@ -444,16 +444,13 @@ if (MSVC)
target_compile_options(unit-test PRIVATE "/bigobj")
endif()
if (TARGET inference_engine)
target_link_libraries(unit-test PRIVATE inference_engine)
endif()
target_link_libraries(unit-test PRIVATE inference_engine)
target_link_libraries(unit-test PRIVATE ie_backend)
if (NGRAPH_ONNX_IMPORT_ENABLE)
target_link_libraries(unit-test PRIVATE onnx_importer)
endif()
target_link_libraries(unit-test PRIVATE ie_backend)
if (NGRAPH_INTERPRETER_ENABLE)
target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE)
target_link_libraries(unit-test PRIVATE interpreter_backend)

View File

@ -50,6 +50,7 @@
#include "util/engine/test_engines.hpp"
#include "util/test_tools.hpp"
#include "util/type_prop.hpp"
#include <cpp/ie_cnn_network.h>
NGRAPH_SUPPRESS_DEPRECATED_START
@ -3546,6 +3547,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D)
test_case.run_with_tolerance_as_fp();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D_reshape)
{
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt"));
InferenceEngine::CNNNetwork net(function);
InferenceEngine::ICNNNetwork::InputShapes shapes = {};
InferenceEngine::SizeVector shape = {1, 1, 4000};
shapes[net.getInputsInfo().begin()->first] = shape;
EXPECT_NO_THROW(net.reshape(shapes));
ASSERT_EQ(shape, net.getOutputsInfo().begin()->second->getDims());
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid)
{
auto function = onnx_import::import_onnx_model(