Added dynamic check for convertFunctionToCNNNetwork functoin (#2797)

* Keep changes

* Added dynamic check for convertFunctionToCNNNetwork

* Fixed test
This commit is contained in:
Gleb Kazantaev 2020-10-23 18:17:26 +03:00 committed by GitHub
parent 0802c40527
commit 33371ca1ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 0 deletions

View File

@ -907,6 +907,30 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
network->setInputInfo(info);
};
// Check if some of function nodes has dynamic input or output shape
// we collect this nodes and then throw an exception with the list
// of dynamic nodes.
std::stringstream err_log;
for (const auto & node : graph->get_ordered_ops()) {
bool is_dynamic = false;
for (const auto & input : node->inputs()) {
if (input.get_partial_shape().is_dynamic()) {
is_dynamic = true;
break;
}
}
for (const auto & output : node->outputs()) {
if (output.get_partial_shape().is_dynamic()) {
is_dynamic = true;
break;
}
}
if (is_dynamic) err_log << node << std::endl;
}
if (!err_log.str().empty()) {
THROW_IE_EXCEPTION << "\nUnsupported dynamic ops: \n" << err_log.str();
}
const CNNNetworkNGraphImpl* nGraphImpl = dynamic_cast<const CNNNetworkNGraphImpl*>(&network);
InputsDataMap thisInputDataMap;

View File

@ -3,6 +3,7 @@
//
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <cpp/ie_cnn_network.h>
#include <legacy/cnn_network_impl.hpp> // deprecated API
@ -206,4 +207,33 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
const std::string resp_msg = err.what();
ASSERT_TRUE(resp_msg.find(ref_msg) != std::string::npos) << resp_msg;
}
}
TEST(ConvertFunctionToCNNNetworkTests, UnsupportedDynamicOps) {
std::shared_ptr<ngraph::Function> f;
{
auto param = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
param->set_friendly_name("param");
auto relu = std::make_shared<ngraph::opset4::Relu>(param);
relu->set_friendly_name("relu");
auto non_zero = std::make_shared<ngraph::opset4::NonZero>(relu);
non_zero->set_friendly_name("non_zero");
auto result = std::make_shared<ngraph::op::Result>(non_zero->output(0));
result->set_friendly_name("result");
f = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{param});
}
InferenceEngine::CNNNetwork nGraphImpl(f);
try {
InferenceEngine::details::convertFunctionToICNNNetwork(f, nGraphImpl);
FAIL() << "InferenceEngineException must be thrown";
} catch(InferenceEngine::details::InferenceEngineException & e) {
EXPECT_THAT(e.what(), testing::HasSubstr(std::string("Unsupported dynamic ops: \n"
"v0::Parameter param () -> (f32?)\n"
"v0::Relu relu (param[0]:f32?) -> (f32?)\n"
"v3::NonZero non_zero (relu[0]:f32?) -> (i64{?,?})\n"
"v0::Result result (non_zero[0]:i64{?,?}) -> (i64{?,?})")));
}
}