[GNA] Fixed cascade concats binding (#11326)

This commit is contained in:
Elizaveta Lobanova 2022-04-04 15:56:13 +03:00 committed by GitHub
parent da8388e263
commit 9b4e8f5b59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 133 additions and 15 deletions

View File

@ -996,21 +996,27 @@ void GNAGraphCompiler::ConcatPrimitive(InferenceEngine::CNNLayerPtr layer) {
}
auto& concatLayerInfo = concat_connection.find(concatLayer->name)->second;
std::function<InferenceEngine::CNNLayerPtr(InferenceEngine::CNNLayerPtr)> find_cascaded_concat_recursively =
[&find_cascaded_concat_recursively](InferenceEngine::CNNLayerPtr concat_candidate) {
if (LayerInfo(concat_candidate).isConcat()) {
return concat_candidate;
}
if (!LayerInfo(concat_candidate).isNonFunctional()) {
return InferenceEngine::CNNLayerPtr(nullptr);
}
for (auto &&child_layer : getInputTo(concat_candidate->outData.front())) {
auto child_concat = find_cascaded_concat_recursively(child_layer.second);
if (child_concat) return child_concat;
}
return InferenceEngine::CNNLayerPtr(nullptr);
};
for (auto &&outLayer : getInputTo(concatLayer->outData.front())) {
auto concatCandidate = outLayer.second;
if (LayerInfo(concatCandidate).isNonFunctional()) {
// searching for next concat
auto isNonFunctional = [](CNNLayerPtr l) {
return LayerInfo(l).isNonFunctional();
};
if (!CNNNetHasNextLayerSkipCertain(concatCandidate, 0, 0, isNonFunctional)) {
continue;
}
concatCandidate = CNNNetGetNextLayerSkipCertain(concatCandidate, 0, 0, isNonFunctional).first;
}
if (!LayerInfo(concatCandidate).isConcat()) {
continue;
}
auto concatCandidate = find_cascaded_concat_recursively(outLayer.second);
if (!concatCandidate) continue;
gnalog() << "Cascaded concat connection found from: " << layer->name << ", to: " << concatCandidate->name << std::endl;
connectOutput(layer, &concatLayerInfo.gna_ptr, concatLayerInfo.reserved_size);
}

View File

@ -34,7 +34,15 @@ namespace {
{{1, 8}}
};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
std::vector<std::vector<size_t>> shape_one_input{
{1, 64},
{1, 128},
{1, 32}
};
std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
};
std::map<std::string, std::string> additional_config = {
@ -43,6 +51,15 @@ namespace {
{"GNA_SCALE_FACTOR_2", "1"}
};
std::vector<std::map<std::string, std::string>> additional_config_one_input = {
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_FP32"}
}
};
INSTANTIATE_TEST_SUITE_P(smoke_cascade_concat, CascadeConcat,
::testing::Combine(
::testing::ValuesIn(shape1),
@ -64,4 +81,12 @@ namespace {
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::Values(additional_config)),
CascadeConcat::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_cascade_concat_reshape, CascadeConcatWithMultiConnReshape,
::testing::Combine(
::testing::ValuesIn(shape_one_input),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(additional_config_one_input)),
CascadeConcatWithMultiConnReshape::getTestCaseName);
} // namespace

View File

@ -12,4 +12,8 @@ TEST_P(CascadeConcat, CompareWithRefs) {
Run();
}
TEST_P(CascadeConcatWithMultiConnReshape, CompareWithRefs) {
Run();
}
} // namespace SubgraphTestsDefinitions

View File

@ -30,4 +30,21 @@ public:
protected:
void SetUp() override;
};
typedef std::tuple<
std::vector<size_t>, //input shapes
InferenceEngine::Precision, //Network precision
std::string, //Device name
std::map<std::string, std::string> //config
> CascadeConcatWithMultiConnReshapeTuple;
class CascadeConcatWithMultiConnReshape
: public testing::WithParamInterface<CascadeConcatWithMultiConnReshapeTuple>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<CascadeConcatWithMultiConnReshapeTuple> &obj);
protected:
void SetUp() override;
};
} // namespace SubgraphTestsDefinitions

View File

@ -56,4 +56,70 @@ void CascadeConcat::SetUp() {
}
function = std::make_shared<ngraph::Function>(results, input, "concat_reshape_reshape_concat_mul");
}
std::string CascadeConcatWithMultiConnReshape::getTestCaseName(const testing::TestParamInfo<CascadeConcatWithMultiConnReshapeTuple> &obj) {
std::vector<size_t> inputShape;
InferenceEngine::Precision netPrecision;
std::string targetName;
std::map<std::string, std::string> additional_config;
std::tie(inputShape, netPrecision, targetName, additional_config) = obj.param;
std::ostringstream results;
results << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
results << "netPRC=" << netPrecision.name() << "_";
results << "targetDevice=" << targetName << "_";
for (auto const& configItem : additional_config) {
results << "_configItem=" << configItem.first << "_" << configItem.second;
}
return results.str();
}
/**
* Tests a case when 2 concats have Squeeze between them and Concat2 is the second connection of Squeeze output
* Input Const1
* | |
* Relu |
* | |
* Concat1
* |
* Squeeze Const2
* | | |
* Relu1 Concat2
* | |
* Unsqueeze1 Relu2
* |
* Unsqueeze2
*/
void CascadeConcatWithMultiConnReshape::SetUp() {
std::vector<size_t> inputShape;
InferenceEngine::Precision netPrecision;
std::map<std::string, std::string> additional_config;
std::tie(inputShape, netPrecision, targetDevice, additional_config) = this->GetParam();
configuration.insert(additional_config.begin(), additional_config.end());
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto inputShapeSqueezed = inputShape;
inputShapeSqueezed.insert(std::begin(inputShapeSqueezed), 1);
auto input = ngraph::builder::makeParams(ngPrc, {inputShapeSqueezed});
auto relu = std::make_shared<ngraph::opset8::Relu>(input[0]);
auto const1 = ngraph::builder::makeConstant(ngPrc, inputShapeSqueezed, std::vector<float>{}, true);
auto concat1 = ngraph::builder::makeConcat({relu, const1}, inputShapeSqueezed.size() - 1);
auto squeeze = ngraph::builder::makeSqueezeUnsqueeze(concat1, ngraph::element::i64, {0}, ngraph::helpers::SqueezeOpType::SQUEEZE);
auto relu1 = std::make_shared<ngraph::opset8::Relu>(squeeze);
auto unsqueeze1 = ngraph::builder::makeSqueezeUnsqueeze(relu1, ngraph::element::i64, {0}, ngraph::helpers::SqueezeOpType::UNSQUEEZE);
auto const2 = ngraph::builder::makeConstant(ngPrc, inputShape, std::vector<float>{}, true);
auto concat2 = ngraph::builder::makeConcat({squeeze, const2}, 1);
// Change concat name to make it the second connection in the map of squeeze output connections
concat2->set_friendly_name("XConcat");
auto relu2 = std::make_shared<ngraph::opset8::Relu>(concat2);
auto unsqueeze2 = ngraph::builder::makeSqueezeUnsqueeze(relu2, ngraph::element::i64, {0}, ngraph::helpers::SqueezeOpType::UNSQUEEZE);
ngraph::ResultVector results = {std::make_shared<ngraph::opset1::Result>(unsqueeze1),
std::make_shared<ngraph::opset1::Result>(unsqueeze2)};
function = std::make_shared<ngraph::Function>(results, input, "CascadeConcatWithMultiConnReshapeTest");
}
} // namespace SubgraphTestsDefinitions