[GNA] Fixed issue for concat connection to memory layer (#6985) (#7058)

* [GNA] Fixed issue for concat connection to memory layer (#6985)

* Fix for concat connection to memory layer

* reverted merge files

* Replaced opset8
This commit is contained in:
Mikhail Ryzhov
2021-08-16 11:04:29 +03:00
committed by GitHub
parent 9a8d8440a5
commit fb3ceb6aa4
5 changed files with 33 additions and 6 deletions

View File

@@ -1031,8 +1031,8 @@ class ScaleFactorPerLayer<InferenceEngine::ConcatLayer*> {
}
quantDataForConCatInput->_dst_quant.SetScale(newScaleFactor);
} else if (restarLayerInfo.isConst()) {
gnalog() << "... warning const layer will be requantized\n";
} else if (restarLayerInfo.isConst() || restarLayerInfo.isMemory()) {
gnalog() << "... warning " << restartedLayer->type << " layer will be requantized\n";
quantDataForConCatInput->_src_quant.SetScale(sourceQuantParams->_dst_quant.GetScale());
quantDataForConCatInput->_dst_quant.SetScale(sourceQuantParams->_dst_quant.GetScale());
} else {

View File

@@ -2238,8 +2238,8 @@ void GNAGraphCompiler::connectOutput(InferenceEngine::CNNLayerPtr layer, void *p
nextMemoryLayer.reserved_size = ALIGN64(memorySize);
} else {
IE_ASSERT(nextMemoryLayer.reserved_size >= ALIGN64(num_data_bytes_out));
gnamem->bind_ptr(ptr, &nextMemoryLayer.gna_ptr, getOffsetForBinding(layer));
// We may need to extend memory buffer if connected input size is bigger, for example for concat connection
gnamem->bind_ptr(ptr, &nextMemoryLayer.gna_ptr, getOffsetForBinding(layer), ALIGN64(num_data_bytes_out));
}
return;
}
@@ -2524,8 +2524,8 @@ GNAPluginNS::ConnectionDetails GNAGraphCompiler::connectInput(CNNLayerPtr layer,
memoryLayer.reserved_size = ALIGN64(memorySize);
} else {
IE_ASSERT(memoryLayer.reserved_size >= ALIGN64(num_data_bytes_in));
gnamem->bind_ptr(ptr, &memoryLayer.gna_ptr, offset);
// We may need to extend memory buffer if connected input size is bigger, for example for concat connection
gnamem->bind_ptr(ptr, &memoryLayer.gna_ptr, offset, ALIGN64(num_data_bytes_in));
}
return prevLayer;

View File

@@ -18,4 +18,11 @@ TEST_P(ConcatMultiInput, CompareWithRefConstOnly) {
Run();
};
TEST_P(ConcatMultiInput, CompareWithRefMemory) {
GenerateMemoryModel();
LoadNetwork();
GenerateInputs();
Infer();
};
} // namespace SubgraphTestsDefinitions

View File

@@ -32,6 +32,7 @@ private:
public:
void GenerateStridedSliceModel();
void GenerateConstOnlyModel();
void GenerateMemoryModel();
static std::string getTestCaseName(testing::TestParamInfo<concatMultiParams> obj);
protected:

View File

@@ -105,4 +105,23 @@ void ConcatMultiInput::GenerateConstOnlyModel() {
function = std::make_shared<ngraph::Function>(results, input_vector, "ConcatConstOnly");
}
void ConcatMultiInput::GenerateMemoryModel() {
int axis = 1;
auto input = ngraph::builder::makeParams(ngPrc, { inputShapes[0] });
auto variable = std::make_shared<ngraph::Variable>(ngraph::VariableInfo{ngraph::PartialShape::dynamic(), ngraph::element::dynamic, "concat_input_memory"});
auto mem_i = std::make_shared<ngraph::opset7::Constant>(ngPrc, inputShapes[0]);
auto mem_r = std::make_shared<ngraph::opset7::ReadValue>(mem_i, variable);
ngraph::OutputVector concat_input;
concat_input.push_back(mem_r);
concat_input.push_back(input.at(0));
auto concat = std::make_shared<ngraph::opset7::Concat>(concat_input, axis);
auto mem_w = std::make_shared<ngraph::opset7::Assign>(input.at(0), variable);
auto res = std::make_shared<ngraph::opset7::Result>(concat);
function = std::make_shared<ngraph::Function>(ngraph::ResultVector{res}, ngraph::SinkVector{mem_w}, input, "ConcatMemory");
}
} // namespace SubgraphTestsDefinitions