[GNA] Enabled the new tests (#9753)

* Enabled the new tests

* Enabled unsupported precisions

* Added additional checks
This commit is contained in:
Mikhail Ryzhov 2022-01-24 13:41:14 +03:00 committed by GitHub
parent 3437a2c9e7
commit aaf4ba0b51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 1 deletions

View File

@ -33,6 +33,44 @@ std::vector<ov::element::Type> prcs = {
ov::element::u64,
};
std::vector<ov::element::Type> supported_input_prcs = {
ov::element::f32,
ov::element::i16,
ov::element::u8
};
class OVInferRequestCheckTensorPrecisionGNA : public OVInferRequestCheckTensorPrecision {
public:
void SetUp() override {
try {
OVInferRequestCheckTensorPrecision::SetUp();
if (std::count(supported_input_prcs.begin(), supported_input_prcs.end(), element_type) == 0) {
FAIL() << "Precision " << element_type.c_type_string() << " is marked as unsupported but the network was loaded successfully";
}
}
catch (std::runtime_error& e) {
const std::string errorMsg = e.what();
const auto expectedMsg = exp_error_str_;
ASSERT_STR_CONTAINS(errorMsg, expectedMsg);
EXPECT_TRUE(errorMsg.find(expectedMsg) != std::string::npos)
<< "Wrong error message, actual error message: " << errorMsg
<< ", expected: " << expectedMsg;
if (std::count(supported_input_prcs.begin(), supported_input_prcs.end(), element_type) == 0) {
GTEST_SKIP_(expectedMsg.c_str());
} else {
FAIL() << "Precision " << element_type.c_type_string() << " is marked as supported but the network was not loaded";
}
}
}
private:
std::string exp_error_str_ = "The plugin does not support input precision";
};
TEST_P(OVInferRequestCheckTensorPrecisionGNA, CheckInputsOutputs) {
Run();
}
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferRequestIOTensorTest,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_GNA),
@ -46,4 +84,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferRequestIOTensorSetPrecision
::testing::ValuesIn(configs)),
OVInferRequestIOTensorSetPrecisionTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferRequestCheckTensorPrecisionGNA,
::testing::Combine(
::testing::ValuesIn(prcs),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs)),
OVInferRequestCheckTensorPrecisionGNA::getTestCaseName);
} // namespace

View File

@ -49,6 +49,7 @@ struct OVInferRequestCheckTensorPrecision : public testing::WithParamInterface<O
static std::string getTestCaseName(const testing::TestParamInfo<OVInferRequestCheckTensorPrecisionParams>& obj);
void SetUp() override;
void TearDown() override;
void Run();
std::shared_ptr<ov::Core> core = utils::PluginCache::get().core();
std::shared_ptr<ov::Model> model;

View File

@ -267,7 +267,7 @@ void OVInferRequestCheckTensorPrecision::TearDown() {
req = {};
}
TEST_P(OVInferRequestCheckTensorPrecision, CheckInputsOutputs) {
void OVInferRequestCheckTensorPrecision::Run() {
EXPECT_EQ(element_type, compModel.input(0).get_element_type());
EXPECT_EQ(element_type, compModel.input(1).get_element_type());
EXPECT_EQ(element_type, compModel.output().get_element_type());
@ -275,6 +275,11 @@ TEST_P(OVInferRequestCheckTensorPrecision, CheckInputsOutputs) {
EXPECT_EQ(element_type, req.get_input_tensor(1).get_element_type());
EXPECT_EQ(element_type, req.get_output_tensor().get_element_type());
}
TEST_P(OVInferRequestCheckTensorPrecision, CheckInputsOutputs) {
Run();
}
} // namespace behavior
} // namespace test
} // namespace ov