[GNA] Replace log::warning() with THROW_GNA_EXCEPTION for unsupported Concat (#16144)
This commit is contained in:
@@ -897,7 +897,7 @@ bool AreLayersSupported(InferenceEngine::CNNNetwork& network, std::string& errMe
|
||||
}
|
||||
} else if (info.isConcat()) {
|
||||
if (!ValidateConcatAxis(layer, errMessage)) {
|
||||
log::warning() << errMessage;
|
||||
THROW_GNA_EXCEPTION << errMessage;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1139,9 +1139,9 @@ void GNAGraphCompiler::ConcatPrimitive(InferenceEngine::CNNLayerPtr layer) {
|
||||
std::ostringstream in_dims_oss;
|
||||
auto in_dims = concatLayer->insData[0].lock()->getDims();
|
||||
std::copy(in_dims.begin(), in_dims.end(), std::ostream_iterator<size_t>(in_dims_oss, ","));
|
||||
log::warning() << "Topology with layer: " + layer->name + ", type: " + layer->type +
|
||||
", and concatenation axis(" + std::to_string(concatLayer->_axis) +
|
||||
") for input dimensions(" + in_dims_oss.str() + ") not supported\n";
|
||||
THROW_GNA_EXCEPTION << "Topology with layer: " + layer->name + ", type: " + layer->type +
|
||||
", and concatenation axis(" + std::to_string(concatLayer->_axis) +
|
||||
") for input dimensions(" + in_dims_oss.str() + ") not supported\n";
|
||||
}
|
||||
|
||||
auto& concatLayerInfo = concat_connection.find(concatLayer->name)->second;
|
||||
|
||||
@@ -269,17 +269,6 @@ public:
|
||||
static const char* getMatch() {
|
||||
return T::getMatch();
|
||||
}
|
||||
void test_output() {
|
||||
std::stringstream what;
|
||||
std::streambuf* sbuf = std::cout.rdbuf();
|
||||
std::streambuf* ebuf = std::cerr.rdbuf();
|
||||
std::cout.rdbuf(what.rdbuf());
|
||||
std::cerr.rdbuf(what.rdbuf());
|
||||
LoadNetwork();
|
||||
EXPECT_TRUE(what.str().find(getMatch()) != std::string::npos);
|
||||
std::cout.rdbuf(sbuf);
|
||||
std::cerr.rdbuf(ebuf);
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
@@ -303,7 +292,7 @@ using ConvConcatNHWCRestrictionsNeg = ConcatRestrictions<ConvConcatNHWCAxis>;
|
||||
using ConvConcatNHWCRestrictionsPos = ConcatRestrictions<ConvConcatNHWCAxis>;
|
||||
|
||||
TEST_P(ReLUConcatRestrictionsNeg, CompareWithRefImpl) {
|
||||
test_output();
|
||||
ExpectLoadNetworkToThrow(getMatch());
|
||||
};
|
||||
|
||||
// TODO: this test is left for future when GNA plugin handles const tranposition required for concats with interleaved
|
||||
@@ -313,8 +302,7 @@ TEST_P(ReLUConcatRestrictionsNeg, CompareWithRefImpl) {
|
||||
//};
|
||||
|
||||
TEST_P(MatMulConcatRestrictionsNeg, CompareWithRefImpl) {
|
||||
test_output();
|
||||
;
|
||||
ExpectLoadNetworkToThrow(getMatch());
|
||||
};
|
||||
|
||||
TEST_P(MatMulConcatRestrictionsPos, CompareWithRefImpl) {
|
||||
@@ -322,13 +310,7 @@ TEST_P(MatMulConcatRestrictionsPos, CompareWithRefImpl) {
|
||||
};
|
||||
|
||||
TEST_P(ConvNCHWConcatRestrictionsNeg, CompareWithRefImpl) {
|
||||
std::string what;
|
||||
try {
|
||||
LoadNetwork();
|
||||
} catch (const std::exception& e) {
|
||||
what.assign(e.what());
|
||||
}
|
||||
EXPECT_TRUE(what.find(getMatch()) != std::string::npos);
|
||||
ExpectLoadNetworkToThrow(getMatch());
|
||||
};
|
||||
|
||||
TEST_P(ConvNCHWConcatRestrictionsPos, CompareWithRefImpl) {
|
||||
@@ -336,7 +318,7 @@ TEST_P(ConvNCHWConcatRestrictionsPos, CompareWithRefImpl) {
|
||||
};
|
||||
|
||||
TEST_P(ConvNHWCConcatRestrictionsNeg, CompareWithRefImpl) {
|
||||
test_output();
|
||||
ExpectLoadNetworkToThrow(getMatch());
|
||||
};
|
||||
|
||||
TEST_P(ConvNHWCConcatRestrictionsPos, CompareWithRefImpl) {
|
||||
@@ -344,7 +326,7 @@ TEST_P(ConvNHWCConcatRestrictionsPos, CompareWithRefImpl) {
|
||||
};
|
||||
|
||||
TEST_P(ConvConcatNHWCRestrictionsNeg, CompareWithRefImpl) {
|
||||
test_output();
|
||||
ExpectLoadNetworkToThrow(getMatch());
|
||||
};
|
||||
|
||||
TEST_P(ConvConcatNHWCRestrictionsPos, CompareWithRefImpl) {
|
||||
@@ -352,8 +334,7 @@ TEST_P(ConvConcatNHWCRestrictionsPos, CompareWithRefImpl) {
|
||||
};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32};
|
||||
const std::vector<std::map<std::string, std::string>> configs = {
|
||||
{{"GNA_DEVICE_MODE", "GNA_SW_FP32"}, {"LOG_LEVEL", "LOG_WARNING"}}};
|
||||
const std::vector<std::map<std::string, std::string>> configs = {{{"GNA_DEVICE_MODE", "GNA_SW_FP32"}}};
|
||||
|
||||
// Negative 4D MatMul cases
|
||||
const std::vector<std::vector<size_t>> inputShapesMatMul4D_neg = {{1, 2, 4, 8}};
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
using namespace SubgraphTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
std::vector<std::vector<size_t>> inputs{{{16, 2}}, {{8, 2}}, {{1, 8}}, {{8, 1}}};
|
||||
std::vector<std::vector<size_t>> inputs1{{{1, 8}}, {{8, 1}}};
|
||||
std::vector<std::vector<size_t>> inputs2{{{16, 2}}, {{8, 2}}};
|
||||
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
@@ -18,7 +19,14 @@ std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_concat_permute,
|
||||
PermuteConcatConcatPermute,
|
||||
::testing::Combine(::testing::ValuesIn(inputs),
|
||||
::testing::Combine(::testing::ValuesIn(inputs1),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA)),
|
||||
PermuteConcatConcatPermute::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_concat_permute,
|
||||
PermuteConcatConcatPermuteNeg,
|
||||
::testing::Combine(::testing::ValuesIn(inputs2),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA)),
|
||||
PermuteConcatConcatPermute::getTestCaseName);
|
||||
|
||||
@@ -10,9 +10,12 @@ using namespace SubgraphTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
std::vector<std::vector<std::vector<size_t>>> inputs{
|
||||
{{1, 8}, {1, 0}, {1, 0}},
|
||||
};
|
||||
|
||||
std::vector<std::vector<std::vector<size_t>>> inputsNeg{
|
||||
{{32, 2}, {1, 0}, {1, 0}},
|
||||
{{8, 2}, {1, 0}, {1, 0}},
|
||||
{{1, 8}, {1, 0}, {1, 0}},
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
@@ -26,4 +29,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_permute,
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA)),
|
||||
PermuteConcatPermute::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_permute,
|
||||
PermuteConcatPermuteNeg,
|
||||
::testing::Combine(::testing::ValuesIn(inputsNeg),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA)),
|
||||
PermuteConcatPermute::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -8,8 +8,14 @@
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using PermuteConcatConcatPermuteNeg = PermuteConcatConcatPermute;
|
||||
|
||||
TEST_P(PermuteConcatConcatPermute, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_P(PermuteConcatConcatPermuteNeg, CompareWithRefs) {
|
||||
ExpectLoadNetworkToThrow("type: Concat, and concatenation axis(");
|
||||
}
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
||||
@@ -8,8 +8,14 @@
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using PermuteConcatPermuteNeg = PermuteConcatPermute;
|
||||
|
||||
TEST_P(PermuteConcatPermute, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_P(PermuteConcatPermuteNeg, CompareWithRefs) {
|
||||
ExpectLoadNetworkToThrow("type: Concat, and concatenation axis(");
|
||||
}
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
||||
@@ -144,6 +144,8 @@ protected:
|
||||
|
||||
virtual void LoadNetwork();
|
||||
|
||||
virtual void ExpectLoadNetworkToThrow(const std::string& msg);
|
||||
|
||||
virtual void GenerateInputs();
|
||||
|
||||
virtual void ConfigureInferRequest();
|
||||
|
||||
@@ -365,6 +365,16 @@ void LayerTestsCommon::LoadNetwork() {
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
}
|
||||
|
||||
void LayerTestsCommon::ExpectLoadNetworkToThrow(const std::string& msg) {
|
||||
std::string what;
|
||||
try {
|
||||
LoadNetwork();
|
||||
} catch (const std::exception& e) {
|
||||
what.assign(e.what());
|
||||
}
|
||||
EXPECT_STR_CONTAINS(what.c_str(), msg.c_str());
|
||||
}
|
||||
|
||||
void LayerTestsCommon::GenerateInputs() {
|
||||
inputs.clear();
|
||||
const auto& inputsInfo = executableNetwork.GetInputsInfo();
|
||||
|
||||
Reference in New Issue
Block a user