Fix for Kaldi models with a batch of more than 1 (#1012)

* Fix kaldi models (batch > 1)

* ngraph codestyle

* fix ngraph to ie conversion

* Added comment

* apply review comments

* Added test for the case using the SetBatchSize function when ReadValue op is in the network

* Check status code instead of message

* Use new ngraph api
This commit is contained in:
Ivan Tikhonov 2020-06-23 08:22:12 +03:00 committed by GitHub
parent b5be90a886
commit 3490b985dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 1 deletions

View File

@ -769,6 +769,20 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
cnnLayer->insData.resize(inputCount); cnnLayer->insData.resize(inputCount);
for (size_t i = 0; i < layer->get_output_size(); i++) { for (size_t i = 0; i < layer->get_output_size(); i++) {
// Memory node with index = 1 has no inputs according to the specification.
// For proper conversion, we must cut off all the layers and data nodes above ReadValue,
// if they are connected only with this layer.
// Now MO generates only constants or constant sub-graphs as input to ReadValue op.
if (std::dynamic_pointer_cast<::ngraph::op::Constant>(layer)) {
bool all_to_read_value = !layer->output(i).get_target_inputs().empty();
for (const auto &output_input : layer->output(i).get_target_inputs()) {
all_to_read_value
&= dynamic_cast<ngraph::op::ReadValue *>(output_input.get_node()) != nullptr;
}
if (all_to_read_value)
continue;
}
if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") { if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
cnnLayer->outData.clear(); cnnLayer->outData.clear();
continue; continue;
@ -776,7 +790,6 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
std::string outName = layer->get_friendly_name(); std::string outName = layer->get_friendly_name();
if (layer->get_output_size() != 1) outName += "." + std::to_string(i); if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str()); DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
SizeVector dims; SizeVector dims;
dims = layer->get_output_shape(i); dims = layer->get_output_shape(i);
for (const auto &dim : dims) { for (const auto &dim : dims) {

View File

@ -17,6 +17,7 @@
#include <ie_core.hpp> #include <ie_core.hpp>
#include <net_pass.h> #include <net_pass.h>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/function.hpp> #include <ngraph/function.hpp>
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
#include <ngraph/op/maximum.hpp> #include <ngraph/op/maximum.hpp>
@ -677,4 +678,25 @@ TEST(CNNNGraphImplTests, TestCheckStats) {
InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph); InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
} }
TEST(CNNNGraphImplTests, CanSetBatchReadValue) {
std::shared_ptr<ngraph::Function> ngraph;
{
auto input = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2});
auto constant = std::make_shared<ngraph::opset3::Constant>(ngraph::element::f32, ngraph::Shape{1, 2},
std::vector<float>{1, 2});
auto read_value = std::make_shared<ngraph::opset3::ReadValue>(constant, "variable_id");
auto add = std::make_shared<ngraph::opset3::Add>(input, read_value);
auto result = std::make_shared<ngraph::op::Result>(add);
ngraph::ParameterVector params = {input};
ngraph::ResultVector results = {result};
ngraph = std::make_shared<ngraph::Function>(results, params);
}
InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
auto status = cnnNet.getCNNNetwork()->setBatchSize(4, nullptr);
EXPECT_EQ(status, StatusCode::OK);
}
IE_SUPPRESS_DEPRECATED_END IE_SUPPRESS_DEPRECATED_END