Fix for Kaldi models with a batch of more than 1 (#1012)
* Fix kaldi models (batch > 1)
* ngraph codestyle
* fix ngraph to ie conversion
* Added comment
* apply review comments
* Added test for the case using the SetBatchSize function when ReadValue op is in the network
* Check status code instead of message
* Use new ngraph api
This commit is contained in: parent b5be90a886, commit 3490b985dd
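For context, a minimal sketch of the scenario this commit targets: reshaping a Kaldi-derived network that contains ReadValue/Memory ops to a batch larger than 1 through the Inference Engine API. The model path, batch value, and device name below are illustrative assumptions, not taken from the patch; the calls assume the Core::ReadNetwork / CNNNetwork::setBatchSize API of this Inference Engine generation.

#include <ie_core.hpp>

int main() {
    InferenceEngine::Core core;
    // Hypothetical IR produced by the Model Optimizer from a Kaldi model.
    auto network = core.ReadNetwork("kaldi_model.xml");
    // Requesting batch > 1 on a network that contains ReadValue is the case
    // covered by the new CanSetBatchReadValue test added in this commit.
    network.setBatchSize(4);
    auto executable = core.LoadNetwork(network, "CPU");
    return 0;
}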
@@ -769,6 +769,20 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
         cnnLayer->insData.resize(inputCount);
 
         for (size_t i = 0; i < layer->get_output_size(); i++) {
+            // Memory node with index = 1 has no inputs according to the specification.
+            // For proper conversion, we must cut off all the layers and data nodes above ReadValue,
+            // if they are connected only with this layer.
+            // Now MO generates only constants or constant sub-graphs as input to ReadValue op.
+            if (std::dynamic_pointer_cast<::ngraph::op::Constant>(layer)) {
+                bool all_to_read_value = !layer->output(i).get_target_inputs().empty();
+                for (const auto &output_input : layer->output(i).get_target_inputs()) {
+                    all_to_read_value
+                            &= dynamic_cast<ngraph::op::ReadValue *>(output_input.get_node()) != nullptr;
+                }
+                if (all_to_read_value)
+                    continue;
+            }
+
             if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
                 cnnLayer->outData.clear();
                 continue;
@@ -776,7 +790,6 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
             std::string outName = layer->get_friendly_name();
             if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
             DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
-
             SizeVector dims;
             dims = layer->get_output_shape(i);
             for (const auto &dim : dims) {
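The heart of the conversion change above is the predicate deciding that a Constant whose output feeds only ReadValue consumers can be cut off. Below is a standalone sketch of that check using the same ngraph node/output calls as the hunk; the helper name feeds_only_read_value and the header paths are assumptions for illustration, not part of the patch.

#include <memory>
#include <ngraph/node.hpp>
#include <ngraph/op/read_value.hpp>

// True only when output i of `node` has at least one consumer and every
// consumer is a ReadValue op; in that case the conversion skips creating
// the corresponding data node.
static bool feeds_only_read_value(const std::shared_ptr<ngraph::Node> &node, size_t i) {
    const auto &consumers = node->output(i).get_target_inputs();
    if (consumers.empty())
        return false;  // mirrors the !...empty() initialization in the patch
    for (const auto &consumer : consumers) {
        if (dynamic_cast<ngraph::op::ReadValue *>(consumer.get_node()) == nullptr)
            return false;  // a non-ReadValue consumer still needs this data node
    }
    return true;
}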
@@ -17,6 +17,7 @@
 #include <ie_core.hpp>
 #include <net_pass.h>
 
+#include <ngraph/opsets/opset3.hpp>
 #include <ngraph/function.hpp>
 #include <ngraph/variant.hpp>
 #include <ngraph/op/maximum.hpp>
@@ -677,4 +678,25 @@ TEST(CNNNGraphImplTests, TestCheckStats) {
     InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
 }
 
+TEST(CNNNGraphImplTests, CanSetBatchReadValue) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        auto input = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2});
+        auto constant = std::make_shared<ngraph::opset3::Constant>(ngraph::element::f32, ngraph::Shape{1, 2},
+                                                                   std::vector<float>{1, 2});
+
+        auto read_value = std::make_shared<ngraph::opset3::ReadValue>(constant, "variable_id");
+        auto add = std::make_shared<ngraph::opset3::Add>(input, read_value);
+        auto result = std::make_shared<ngraph::op::Result>(add);
+
+        ngraph::ParameterVector params = {input};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<ngraph::Function>(results, params);
+    }
+
+    InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+    auto status = cnnNet.getCNNNetwork()->setBatchSize(4, nullptr);
+    EXPECT_EQ(status, StatusCode::OK);
+}
 IE_SUPPRESS_DEPRECATED_END
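On "Check status code instead of message" from the description: the deprecated ICNNNetwork::setBatchSize overload used in the test reports failure through its returned StatusCode and an optional ResponseDesc, so the assertion compares codes rather than parsing error text. A hedged sketch of that pattern follows; the helper name try_set_batch and the header paths are assumptions.

#include <iostream>
#include <ie_common.h>
#include <ie_icnn_network.hpp>

// Attempt to change the batch size; the returned code is what callers (and
// the test above) branch on, the ResponseDesc text is diagnostic only.
InferenceEngine::StatusCode try_set_batch(InferenceEngine::ICNNNetwork &net, size_t batch) {
    InferenceEngine::ResponseDesc resp;
    InferenceEngine::StatusCode status = net.setBatchSize(batch, &resp);
    if (status != InferenceEngine::StatusCode::OK)
        std::cerr << resp.msg << std::endl;
    return status;
}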