Removed hardcoded shapes from LogSoftmax operation (#4475)
* Removed hardcoded shapes from LogSoftmax operation * Added tests * Fixed comments
This commit is contained in:
parent
c38e7a2986
commit
1458ba392e
@ -40,15 +40,8 @@ namespace ngraph
|
||||
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
||||
|
||||
const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
|
||||
if (data.get_partial_shape().is_static())
|
||||
{
|
||||
return ngraph::builder::opset1::reshape(result, data.get_shape());
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
|
||||
}
|
||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
|
||||
}
|
||||
|
||||
OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)
|
||||
|
@ -444,16 +444,13 @@ if (MSVC)
|
||||
target_compile_options(unit-test PRIVATE "/bigobj")
|
||||
endif()
|
||||
|
||||
if (TARGET inference_engine)
|
||||
target_link_libraries(unit-test PRIVATE inference_engine)
|
||||
endif()
|
||||
target_link_libraries(unit-test PRIVATE inference_engine)
|
||||
target_link_libraries(unit-test PRIVATE ie_backend)
|
||||
|
||||
if (NGRAPH_ONNX_IMPORT_ENABLE)
|
||||
target_link_libraries(unit-test PRIVATE onnx_importer)
|
||||
endif()
|
||||
|
||||
target_link_libraries(unit-test PRIVATE ie_backend)
|
||||
|
||||
if (NGRAPH_INTERPRETER_ENABLE)
|
||||
target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE)
|
||||
target_link_libraries(unit-test PRIVATE interpreter_backend)
|
||||
|
@ -50,6 +50,7 @@
|
||||
#include "util/engine/test_engines.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
@ -3546,6 +3547,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D)
|
||||
test_case.run_with_tolerance_as_fp();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D_reshape)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt"));
|
||||
InferenceEngine::CNNNetwork net(function);
|
||||
InferenceEngine::ICNNNetwork::InputShapes shapes = {};
|
||||
InferenceEngine::SizeVector shape = {1, 1, 4000};
|
||||
shapes[net.getInputsInfo().begin()->first] = shape;
|
||||
EXPECT_NO_THROW(net.reshape(shapes));
|
||||
ASSERT_EQ(shape, net.getOutputsInfo().begin()->second->getDims());
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
|
Loading…
Reference in New Issue
Block a user