Removed hardcoded shapes from LogSoftmax operation (#4475)
* Removed hardcoded shapes from LogSoftmax operation * Added tests * Fixed comments
This commit is contained in:
parent
c38e7a2986
commit
1458ba392e
@ -40,15 +40,8 @@ namespace ngraph
|
|||||||
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
std::make_shared<default_opset::Subtract>(coerced_data, max);
|
||||||
|
|
||||||
const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
|
const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
|
||||||
if (data.get_partial_shape().is_static())
|
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||||
{
|
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
|
||||||
return ngraph::builder::opset1::reshape(result, data.get_shape());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
|
|
||||||
return std::make_shared<default_opset::Reshape>(result, data_shape, false);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)
|
OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)
|
||||||
|
@ -444,16 +444,13 @@ if (MSVC)
|
|||||||
target_compile_options(unit-test PRIVATE "/bigobj")
|
target_compile_options(unit-test PRIVATE "/bigobj")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (TARGET inference_engine)
|
target_link_libraries(unit-test PRIVATE inference_engine)
|
||||||
target_link_libraries(unit-test PRIVATE inference_engine)
|
target_link_libraries(unit-test PRIVATE ie_backend)
|
||||||
endif()
|
|
||||||
|
|
||||||
if (NGRAPH_ONNX_IMPORT_ENABLE)
|
if (NGRAPH_ONNX_IMPORT_ENABLE)
|
||||||
target_link_libraries(unit-test PRIVATE onnx_importer)
|
target_link_libraries(unit-test PRIVATE onnx_importer)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_link_libraries(unit-test PRIVATE ie_backend)
|
|
||||||
|
|
||||||
if (NGRAPH_INTERPRETER_ENABLE)
|
if (NGRAPH_INTERPRETER_ENABLE)
|
||||||
target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE)
|
target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE)
|
||||||
target_link_libraries(unit-test PRIVATE interpreter_backend)
|
target_link_libraries(unit-test PRIVATE interpreter_backend)
|
||||||
|
@ -50,6 +50,7 @@
|
|||||||
#include "util/engine/test_engines.hpp"
|
#include "util/engine/test_engines.hpp"
|
||||||
#include "util/test_tools.hpp"
|
#include "util/test_tools.hpp"
|
||||||
#include "util/type_prop.hpp"
|
#include "util/type_prop.hpp"
|
||||||
|
#include <cpp/ie_cnn_network.h>
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
|
|
||||||
@ -3546,6 +3547,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D)
|
|||||||
test_case.run_with_tolerance_as_fp();
|
test_case.run_with_tolerance_as_fp();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D_reshape)
|
||||||
|
{
|
||||||
|
const auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt"));
|
||||||
|
InferenceEngine::CNNNetwork net(function);
|
||||||
|
InferenceEngine::ICNNNetwork::InputShapes shapes = {};
|
||||||
|
InferenceEngine::SizeVector shape = {1, 1, 4000};
|
||||||
|
shapes[net.getInputsInfo().begin()->first] = shape;
|
||||||
|
EXPECT_NO_THROW(net.reshape(shapes));
|
||||||
|
ASSERT_EQ(shape, net.getOutputsInfo().begin()->second->getDims());
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid)
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid)
|
||||||
{
|
{
|
||||||
auto function = onnx_import::import_onnx_model(
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
Loading…
Reference in New Issue
Block a user