diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
index b34e599b9d4..89df86b7ed0 100644
--- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
+++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
@@ -769,6 +769,20 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
         cnnLayer->insData.resize(inputCount);
 
         for (size_t i = 0; i < layer->get_output_size(); i++) {
+            // Memory node with index = 1 has no inputs according to the specification.
+            // For proper conversion, we must cut off all the layers and data nodes above ReadValue,
+            // if they are connected only with this layer.
+            // Now MO generates only constants or constant sub-graphs as input to ReadValue op.
+            if (std::dynamic_pointer_cast<::ngraph::op::Constant>(layer)) {
+                bool all_to_read_value = !layer->output(i).get_target_inputs().empty();
+                for (const auto &output_input : layer->output(i).get_target_inputs()) {
+                    all_to_read_value
+                            &= dynamic_cast<::ngraph::op::ReadValue *>(output_input.get_node()) != nullptr;
+                }
+                if (all_to_read_value)
+                    continue;
+            }
+
             if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
                 cnnLayer->outData.clear();
                 continue;
@@ -776,7 +790,6 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
             std::string outName = layer->get_friendly_name();
             if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
             DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
-
             SizeVector dims;
             dims = layer->get_output_shape(i);
             for (const auto &dim : dims) {
diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp
index 3ef01377b79..08414493378 100644
--- a/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp
+++ b/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp
@@ -17,6 +17,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -677,4 +678,25 @@ TEST(CNNNGraphImplTests, TestCheckStats) {
     InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
 }
 
+TEST(CNNNGraphImplTests, CanSetBatchReadValue) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2});
+        auto constant = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, ngraph::Shape{1, 2},
+                                                               std::vector<float>{1, 2});
+
+        auto read_value = std::make_shared<ngraph::op::ReadValue>(constant, "variable_id");
+        auto add = std::make_shared<ngraph::op::v1::Add>(input, read_value);
+        auto result = std::make_shared<ngraph::op::Result>(add);
+
+        ngraph::ParameterVector params = {input};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<ngraph::Function>(results, params);
+    }
+
+    InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+    auto status = cnnNet.getCNNNetwork()->setBatchSize(4, nullptr);
+    EXPECT_EQ(status, StatusCode::OK);
+}
 IE_SUPPRESS_DEPRECATED_END
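
Note (not part of the patch): a minimal standalone sketch of the predicate the first hunk adds, assuming the ngraph API used in the diff (op::Constant, op::ReadValue, Output::get_target_inputs); the main() wrapper and the printed output are illustrative only. It builds a Constant whose sole consumer is a ReadValue and evaluates the same all-consumers-are-ReadValue check that makes the conversion skip creating a data node for that Constant:

#include <iostream>
#include <memory>
#include <vector>
#include <ngraph/ngraph.hpp>

int main() {
    // A Constant feeding only a ReadValue, as MO generates for memory states.
    auto constant = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 2},
                                                 std::vector<float>{1.0f, 2.0f});
    auto read_value = std::make_shared<ngraph::op::ReadValue>(constant, "variable_id");

    // Same predicate as in convertFunctionToICNNNetwork: true only when the
    // output has consumers and every one of them is a ReadValue node.
    bool all_to_read_value = !constant->output(0).get_target_inputs().empty();
    for (const auto &target : constant->output(0).get_target_inputs()) {
        all_to_read_value &= dynamic_cast<ngraph::op::ReadValue *>(target.get_node()) != nullptr;
    }

    std::cout << std::boolalpha << all_to_read_value << std::endl;  // expected: true
    return 0;
}

If the constant gained a second, non-ReadValue consumer, the loop would flip the flag to false and the converter would keep the constant's data node, which is why the patch cuts the sub-graph only when it feeds ReadValue exclusively.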