[GNA] Fixed cascade concats binding (#11326)
This commit is contained in:
parent
da8388e263
commit
9b4e8f5b59
@ -996,21 +996,27 @@ void GNAGraphCompiler::ConcatPrimitive(InferenceEngine::CNNLayerPtr layer) {
|
||||
}
|
||||
|
||||
auto& concatLayerInfo = concat_connection.find(concatLayer->name)->second;
|
||||
std::function<InferenceEngine::CNNLayerPtr(InferenceEngine::CNNLayerPtr)> find_cascaded_concat_recursively =
|
||||
[&find_cascaded_concat_recursively](InferenceEngine::CNNLayerPtr concat_candidate) {
|
||||
if (LayerInfo(concat_candidate).isConcat()) {
|
||||
return concat_candidate;
|
||||
}
|
||||
|
||||
if (!LayerInfo(concat_candidate).isNonFunctional()) {
|
||||
return InferenceEngine::CNNLayerPtr(nullptr);
|
||||
}
|
||||
|
||||
for (auto &&child_layer : getInputTo(concat_candidate->outData.front())) {
|
||||
auto child_concat = find_cascaded_concat_recursively(child_layer.second);
|
||||
if (child_concat) return child_concat;
|
||||
}
|
||||
|
||||
return InferenceEngine::CNNLayerPtr(nullptr);
|
||||
};
|
||||
|
||||
for (auto &&outLayer : getInputTo(concatLayer->outData.front())) {
|
||||
auto concatCandidate = outLayer.second;
|
||||
if (LayerInfo(concatCandidate).isNonFunctional()) {
|
||||
// searching for next concat
|
||||
auto isNonFunctional = [](CNNLayerPtr l) {
|
||||
return LayerInfo(l).isNonFunctional();
|
||||
};
|
||||
if (!CNNNetHasNextLayerSkipCertain(concatCandidate, 0, 0, isNonFunctional)) {
|
||||
continue;
|
||||
}
|
||||
concatCandidate = CNNNetGetNextLayerSkipCertain(concatCandidate, 0, 0, isNonFunctional).first;
|
||||
}
|
||||
if (!LayerInfo(concatCandidate).isConcat()) {
|
||||
continue;
|
||||
}
|
||||
auto concatCandidate = find_cascaded_concat_recursively(outLayer.second);
|
||||
if (!concatCandidate) continue;
|
||||
gnalog() << "Cascaded concat connection found from: " << layer->name << ", to: " << concatCandidate->name << std::endl;
|
||||
connectOutput(layer, &concatLayerInfo.gna_ptr, concatLayerInfo.reserved_size);
|
||||
}
|
||||
|
@ -34,7 +34,15 @@ namespace {
|
||||
{{1, 8}}
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
|
||||
std::vector<std::vector<size_t>> shape_one_input{
|
||||
{1, 64},
|
||||
{1, 128},
|
||||
{1, 32}
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
|
||||
std::map<std::string, std::string> additional_config = {
|
||||
@ -43,6 +51,15 @@ namespace {
|
||||
{"GNA_SCALE_FACTOR_2", "1"}
|
||||
};
|
||||
|
||||
std::vector<std::map<std::string, std::string>> additional_config_one_input = {
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_FP32"}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_cascade_concat, CascadeConcat,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(shape1),
|
||||
@ -64,4 +81,12 @@ namespace {
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::Values(additional_config)),
|
||||
CascadeConcat::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_cascade_concat_reshape, CascadeConcatWithMultiConnReshape,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(shape_one_input),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(additional_config_one_input)),
|
||||
CascadeConcatWithMultiConnReshape::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -12,4 +12,8 @@ TEST_P(CascadeConcat, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_P(CascadeConcatWithMultiConnReshape, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
@ -30,4 +30,21 @@ public:
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
std::vector<size_t>, //input shapes
|
||||
InferenceEngine::Precision, //Network precision
|
||||
std::string, //Device name
|
||||
std::map<std::string, std::string> //config
|
||||
> CascadeConcatWithMultiConnReshapeTuple;
|
||||
|
||||
class CascadeConcatWithMultiConnReshape
|
||||
: public testing::WithParamInterface<CascadeConcatWithMultiConnReshapeTuple>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<CascadeConcatWithMultiConnReshapeTuple> &obj);
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
@ -56,4 +56,70 @@ void CascadeConcat::SetUp() {
|
||||
}
|
||||
function = std::make_shared<ngraph::Function>(results, input, "concat_reshape_reshape_concat_mul");
|
||||
}
|
||||
|
||||
std::string CascadeConcatWithMultiConnReshape::getTestCaseName(const testing::TestParamInfo<CascadeConcatWithMultiConnReshapeTuple> &obj) {
|
||||
std::vector<size_t> inputShape;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetName;
|
||||
std::map<std::string, std::string> additional_config;
|
||||
std::tie(inputShape, netPrecision, targetName, additional_config) = obj.param;
|
||||
std::ostringstream results;
|
||||
|
||||
results << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
|
||||
results << "netPRC=" << netPrecision.name() << "_";
|
||||
results << "targetDevice=" << targetName << "_";
|
||||
for (auto const& configItem : additional_config) {
|
||||
results << "_configItem=" << configItem.first << "_" << configItem.second;
|
||||
}
|
||||
return results.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests a case when 2 concats have Squeeze between them and Concat2 is the second connection of Squeeze output
|
||||
* Input Const1
|
||||
* | |
|
||||
* Relu |
|
||||
* | |
|
||||
* Concat1
|
||||
* |
|
||||
* Squeeze Const2
|
||||
* | | |
|
||||
* Relu1 Concat2
|
||||
* | |
|
||||
* Unsqueeze1 Relu2
|
||||
* |
|
||||
* Unsqueeze2
|
||||
*/
|
||||
void CascadeConcatWithMultiConnReshape::SetUp() {
|
||||
std::vector<size_t> inputShape;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::map<std::string, std::string> additional_config;
|
||||
std::tie(inputShape, netPrecision, targetDevice, additional_config) = this->GetParam();
|
||||
configuration.insert(additional_config.begin(), additional_config.end());
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
auto inputShapeSqueezed = inputShape;
|
||||
inputShapeSqueezed.insert(std::begin(inputShapeSqueezed), 1);
|
||||
auto input = ngraph::builder::makeParams(ngPrc, {inputShapeSqueezed});
|
||||
auto relu = std::make_shared<ngraph::opset8::Relu>(input[0]);
|
||||
auto const1 = ngraph::builder::makeConstant(ngPrc, inputShapeSqueezed, std::vector<float>{}, true);
|
||||
auto concat1 = ngraph::builder::makeConcat({relu, const1}, inputShapeSqueezed.size() - 1);
|
||||
|
||||
auto squeeze = ngraph::builder::makeSqueezeUnsqueeze(concat1, ngraph::element::i64, {0}, ngraph::helpers::SqueezeOpType::SQUEEZE);
|
||||
|
||||
auto relu1 = std::make_shared<ngraph::opset8::Relu>(squeeze);
|
||||
auto unsqueeze1 = ngraph::builder::makeSqueezeUnsqueeze(relu1, ngraph::element::i64, {0}, ngraph::helpers::SqueezeOpType::UNSQUEEZE);
|
||||
|
||||
auto const2 = ngraph::builder::makeConstant(ngPrc, inputShape, std::vector<float>{}, true);
|
||||
auto concat2 = ngraph::builder::makeConcat({squeeze, const2}, 1);
|
||||
// Change concat name to make it the second connection in the map of squeeze output connections
|
||||
concat2->set_friendly_name("XConcat");
|
||||
|
||||
auto relu2 = std::make_shared<ngraph::opset8::Relu>(concat2);
|
||||
auto unsqueeze2 = ngraph::builder::makeSqueezeUnsqueeze(relu2, ngraph::element::i64, {0}, ngraph::helpers::SqueezeOpType::UNSQUEEZE);
|
||||
ngraph::ResultVector results = {std::make_shared<ngraph::opset1::Result>(unsqueeze1),
|
||||
std::make_shared<ngraph::opset1::Result>(unsqueeze2)};
|
||||
|
||||
function = std::make_shared<ngraph::Function>(results, input, "CascadeConcatWithMultiConnReshapeTest");
|
||||
}
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user