[CompileModelCacheTestBase] Avoid tests with ops and them unsupported precision (#14582)

* [CompileModelCacheTestBase] Avoid tests with ops and them unsupport presicion

* Update functional and plugins tests
This commit is contained in:
Sofya Balandina 2022-12-21 10:23:03 +00:00 committed by GitHub
parent 20ca25bfca
commit 7856045497
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 258 additions and 45 deletions

View File

@ -22,6 +22,11 @@ namespace {
ngraph::element::u16,
};
static const std::vector<ngraph::element::Type> floatPrecisionsCPU = {
ngraph::element::f32,
ngraph::element::f16,
};
static const std::vector<std::size_t> batchSizesCPU = {
1, 2
};
@ -81,13 +86,22 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(CompileModelCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(precisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU_Float, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatPrecisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU_Internal, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(internal_functions_cpu()),
@ -103,22 +117,40 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_CachingSupportCase, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(CompileModelCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(precisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(autoConfigs)),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_CachingSupportCase_Float, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatPrecisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(autoConfigs)),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Auto_CachingSupportCase_CPU, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(CompileModelCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(precisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_AUTO),
::testing::ValuesIn(autoConfigs)),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Auto_CachingSupportCase_CPU_Float, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatPrecisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_AUTO),
::testing::ValuesIn(autoConfigs)),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Auto_CachingSupportCase_CPU_Internal, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(internal_functions_cpu()),

View File

@ -22,6 +22,11 @@ namespace {
ngraph::element::u16,
};
static const std::vector<ngraph::element::Type> floatPrecisionsCPU = {
ngraph::element::f32,
ngraph::element::f16
};
static const std::vector<std::size_t> batchSizesCPU = {
1, 2
};
@ -81,12 +86,20 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(LoadNetworkCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(precisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU_Float, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatPrecisionsCPU),
::testing::ValuesIn(batchSizesCPU),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU_Internal, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(internal_functions_cpu()),

View File

@ -30,12 +30,40 @@ static const std::vector<std::size_t> ovBatchSizesTemplate = {
1, 2
};
static const std::vector<ov::element::Type> ovElemAnyNumericTypesTemplate(ovElemTypesTemplate.begin(),
ovElemTypesTemplate.end() - 1);
static const std::vector<ov::element::Type> ovElemAnyFloatingPointTypesTemplate(ovElemTypesTemplate.begin(),
ovElemTypesTemplate.begin() + 3);
INSTANTIATE_TEST_SUITE_P(ov_plugin, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(CompileModelCacheTestBase::getAnyTypeOnlyFunctions()),
::testing::ValuesIn(ovElemTypesTemplate),
::testing::ValuesIn(ovBatchSizesTemplate),
::testing::ValuesIn(return_all_possible_device_combination()),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
// Convolution/UnaryElementwiseArithmetic/BinaryElementwiseArithmetic is not supported boolean elemnt type
INSTANTIATE_TEST_SUITE_P(ov_plugin_numeric, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getNumericTypeOnlyFunctions()),
::testing::ValuesIn(ovElemAnyNumericTypesTemplate),
::testing::ValuesIn(ovBatchSizesTemplate),
::testing::ValuesIn(return_all_possible_device_combination()),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
// LSTMcell supported floating-point element type
INSTANTIATE_TEST_SUITE_P(ov_plugin_floating_point, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(ovElemAnyFloatingPointTypesTemplate),
::testing::ValuesIn(ovBatchSizesTemplate),
::testing::ValuesIn(return_all_possible_device_combination()),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
} // namespace

View File

@ -30,11 +30,33 @@ static const std::vector<std::size_t> batchSizesTemplate = {
1, 2
};
INSTANTIATE_TEST_SUITE_P(ie_plugin, LoadNetworkCacheTestBase,
static const std::vector<ov::element::Type> numericPrecisionsTemplate(precisionsTemplate.begin(),
precisionsTemplate.end() - 1);
static const std::vector<ov::element::Type> floatingPointPrecisionsTemplate(precisionsTemplate.begin(),
precisionsTemplate.begin() + 3);
INSTANTIATE_TEST_SUITE_P(ie_plugin_any_type, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(LoadNetworkCacheTestBase::getAnyTypeOnlyFunctions()),
::testing::ValuesIn(precisionsTemplate),
::testing::ValuesIn(batchSizesTemplate),
::testing::ValuesIn(return_all_possible_device_combination())),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(ie_plugin_numeric, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getNumericTypeOnlyFunctions()),
::testing::ValuesIn(numericPrecisionsTemplate),
::testing::ValuesIn(batchSizesTemplate),
::testing::ValuesIn(return_all_possible_device_combination())),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(ie_plugin_float, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatingPointPrecisionsTemplate),
::testing::ValuesIn(batchSizesTemplate),
::testing::ValuesIn(return_all_possible_device_combination())),
LoadNetworkCacheTestBase::getTestCaseName);
} // namespace

View File

@ -22,15 +22,29 @@ namespace {
1, 2
};
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_GPU, CompileModelCacheTestBase,
static const std::vector<ov::element::Type> floatingPointPrecisionsGPU = {
ngraph::element::f32,
ngraph::element::f16,
};
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCaseAnyType_GPU, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(CompileModelCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(precisionsGPU),
::testing::ValuesIn(batchSizesGPU),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCaseFloat_GPU, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatingPointPrecisionsGPU),
::testing::ValuesIn(batchSizesGPU),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_KernelCachingSupportCase_GPU, CompiledKernelsCacheTest,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_GPU),

View File

@ -18,18 +18,31 @@ namespace {
ngraph::element::u16,
};
static const std::vector<ngraph::element::Type> floatPrecisionsGPU = {
ngraph::element::f32,
ngraph::element::f16
};
static const std::vector<std::size_t> batchSizesGPU = {
1, 2
};
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_GPU, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(LoadNetworkCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(precisionsGPU),
::testing::ValuesIn(batchSizesGPU),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_GPU_Float, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(floatPrecisionsGPU),
::testing::ValuesIn(batchSizesGPU),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_KernelCachingSupportCase_GPU, LoadNetworkCompiledKernelsCacheTest,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_GPU),

View File

@ -15,6 +15,11 @@ namespace {
ngraph::element::u8,
};
static const std::vector<ngraph::element::Type> nightly_floatPrecisionsMyriad = {
ngraph::element::f32,
ngraph::element::f16
};
static const std::vector<ngraph::element::Type> smoke_precisionsMyriad = {
ngraph::element::f32,
};
@ -42,10 +47,19 @@ namespace {
INSTANTIATE_TEST_SUITE_P(nightly_CachingSupportCase_Myriad, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(CompileModelCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(nightly_precisionsMyriad),
::testing::ValuesIn(batchSizesMyriad),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_CachingSupportCase_Myriad_Float, CompileModelCacheTestBase,
::testing::Combine(
::testing::ValuesIn(CompileModelCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(nightly_floatPrecisionsMyriad),
::testing::ValuesIn(batchSizesMyriad),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD),
::testing::Values(ov::AnyMap{})),
CompileModelCacheTestBase::getTestCaseName);
} // namespace

View File

@ -15,6 +15,11 @@ namespace {
ngraph::element::u8,
};
static const std::vector<ngraph::element::Type> nightly_floatPrecisionsMyriad = {
ngraph::element::f32,
ngraph::element::f16
};
static const std::vector<ngraph::element::Type> smoke_precisionsMyriad = {
ngraph::element::f32,
};
@ -41,9 +46,17 @@ namespace {
INSTANTIATE_TEST_SUITE_P(nightly_CachingSupportCase_Myriad, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getStandardFunctions()),
::testing::ValuesIn(LoadNetworkCacheTestBase::getNumericAnyTypeFunctions()),
::testing::ValuesIn(nightly_precisionsMyriad),
::testing::ValuesIn(batchSizesMyriad),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)),
LoadNetworkCacheTestBase::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_CachingSupportCase_Myriad_Float, LoadNetworkCacheTestBase,
::testing::Combine(
::testing::ValuesIn(LoadNetworkCacheTestBase::getFloatingPointOnlyFunctions()),
::testing::ValuesIn(nightly_floatPrecisionsMyriad),
::testing::ValuesIn(batchSizesMyriad),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)),
LoadNetworkCacheTestBase::getTestCaseName);
} // namespace

View File

@ -33,6 +33,9 @@ using compileModelCacheParams = std::tuple<
ov::AnyMap // device configuration
>;
using ovModelIS = std::function<std::shared_ptr<ov::Model>(std::vector<size_t> inputShape,
ov::element::Type_t type)>;
class CompileModelCacheTestBase : public testing::WithParamInterface<compileModelCacheParams>,
virtual public SubgraphBaseTest,
virtual public OVPluginTestBase {
@ -49,7 +52,14 @@ public:
void run() override;
bool importExportSupported(ov::Core &core) const;
// Wrapper of most part of available builder functions
static ovModelGenerator inputShapeWrapper(ovModelIS fun, std::vector<size_t> inputShape);
// Default functions and precisions that can be used as test parameters
static std::vector<ovModelWithName> getAnyTypeOnlyFunctions();
static std::vector<ovModelWithName> getNumericTypeOnlyFunctions();
static std::vector<ovModelWithName> getNumericAnyTypeFunctions();
static std::vector<ovModelWithName> getFloatingPointOnlyFunctions();
static std::vector<ovModelWithName> getStandardFunctions();
};

View File

@ -20,6 +20,8 @@
using ngraphFunctionGenerator = std::function<std::shared_ptr<ngraph::Function>(ngraph::element::Type, std::size_t)>;
using nGraphFunctionWithName = std::tuple<ngraphFunctionGenerator, std::string>;
using ngraphFunctionIS = std::function<std::shared_ptr<ngraph::Function>(std::vector<size_t> inputShape,
ngraph::element::Type_t type)>;
using loadNetworkCacheParams = std::tuple<
nGraphFunctionWithName, // ngraph function with friendly name
@ -45,7 +47,13 @@ public:
bool importExportSupported(InferenceEngine::Core& ie) const;
// Wrapper of most part of available builder functions
static ngraphFunctionGenerator inputShapeWrapper(ngraphFunctionIS fun, std::vector<size_t> inputShape);
// Default functions and precisions that can be used as test parameters
static std::vector<nGraphFunctionWithName> getAnyTypeOnlyFunctions();
static std::vector<nGraphFunctionWithName> getNumericTypeOnlyFunctions();
static std::vector<nGraphFunctionWithName> getNumericAnyTypeFunctions();
static std::vector<nGraphFunctionWithName> getFloatingPointOnlyFunctions();
static std::vector<nGraphFunctionWithName> getStandardFunctions();
};

View File

@ -60,18 +60,15 @@ static std::shared_ptr<ov::Model> simple_function_relu(ov::element::Type type, s
return func;
}
std::vector<ovModelWithName> CompileModelCacheTestBase::getStandardFunctions() {
// Wrapper of most part of available builder functions
using ovModelIS = std::function<std::shared_ptr<ov::Model>(std::vector<size_t> inputShape,
ov::element::Type_t type)>;
auto inputShapeWrapper = [](ovModelIS fun, std::vector<size_t> inputShape) {
return [fun, inputShape](ngraph::element::Type type, std::size_t batchSize) {
auto shape = inputShape;
shape[0] = batchSize;
return fun(shape, type);
};
ovModelGenerator CompileModelCacheTestBase::inputShapeWrapper(ovModelIS fun, std::vector<size_t> inputShape) {
return [fun, inputShape](ngraph::element::Type type, std::size_t batchSize) {
auto shape = inputShape;
shape[0] = batchSize;
return fun(shape, type);
};
}
std::vector<ovModelWithName> CompileModelCacheTestBase::getNumericTypeOnlyFunctions() {
std::vector<ovModelWithName> res;
res.push_back(ovModelWithName { simple_function_multiply, "SimpleFunctionMultiply"});
res.push_back(ovModelWithName { simple_function_relu, "SimpleFunctionRelu"});
@ -84,9 +81,6 @@ std::vector<ovModelWithName> CompileModelCacheTestBase::getStandardFunctions() {
res.push_back(ovModelWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeKSOFunction, {1, 4, 20, 20}),
"KSOFunction"});
res.push_back(ovModelWithName { [](ngraph::element::Type type, size_t batchSize) {
return ngraph::builder::subgraph::makeTIwithLSTMcell(type, batchSize);
}, "TIwithLSTMcell1"});
res.push_back(ovModelWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeSingleConv, {1, 3, 24, 24}),
"SingleConv"});
@ -108,16 +102,46 @@ std::vector<ovModelWithName> CompileModelCacheTestBase::getStandardFunctions() {
res.push_back(ovModelWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeConvBias, {1, 3, 24, 24}),
"ConvBias"});
res.push_back(ovModelWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeReadConcatSplitAssign, {1, 1, 2, 4}),
"ReadConcatSplitAssign"});
res.push_back(ovModelWithName{
inputShapeWrapper(ngraph::builder::subgraph::makeMatMulBias, {1, 3, 24, 24}),
"MatMulBias" });
return res;
}
std::vector<ovModelWithName> CompileModelCacheTestBase::getAnyTypeOnlyFunctions() {
std::vector<ovModelWithName> res;
res.push_back(ovModelWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeReadConcatSplitAssign, {1, 1, 2, 4}),
"ReadConcatSplitAssign"});
return res;
}
std::vector<ovModelWithName> CompileModelCacheTestBase::getFloatingPointOnlyFunctions() {
std::vector<ovModelWithName> res;
res.push_back(ovModelWithName { [](ngraph::element::Type type, size_t batchSize) {
return ngraph::builder::subgraph::makeTIwithLSTMcell(type, batchSize);
}, "TIwithLSTMcell1"});
return res;
}
std::vector<ovModelWithName> CompileModelCacheTestBase::getNumericAnyTypeFunctions() {
std::vector<ovModelWithName> funcs = CompileModelCacheTestBase::getAnyTypeOnlyFunctions();
std::vector<ovModelWithName> numericType = CompileModelCacheTestBase::getNumericTypeOnlyFunctions();
funcs.insert(funcs.end(), numericType.begin(), numericType.end());
return funcs;
}
std::vector<ovModelWithName> CompileModelCacheTestBase::getStandardFunctions() {
std::vector<ovModelWithName> funcs = CompileModelCacheTestBase::getAnyTypeOnlyFunctions();
std::vector<ovModelWithName> numericType = CompileModelCacheTestBase::getNumericTypeOnlyFunctions();
funcs.insert(funcs.end(), numericType.begin(), numericType.end());
std::vector<ovModelWithName> floatType = CompileModelCacheTestBase::getFloatingPointOnlyFunctions();
funcs.insert(funcs.end(), floatType.begin(), floatType.end());
return funcs;
}
bool CompileModelCacheTestBase::importExportSupported(ov::Core& core) const {
auto supportedProperties = core.get_property(targetDevice, ov::supported_properties);
if (std::find(supportedProperties.begin(), supportedProperties.end(), ov::device::capabilities) == supportedProperties.end()) {

View File

@ -56,18 +56,15 @@ static std::shared_ptr<ngraph::Function> simple_function_relu(ngraph::element::T
return func;
}
std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getStandardFunctions() {
// Wrapper of most part of available builder functions
using ngraphFunctionIS = std::function<std::shared_ptr<ngraph::Function>(std::vector<size_t> inputShape,
ngraph::element::Type_t type)>;
auto inputShapeWrapper = [](ngraphFunctionIS fun, std::vector<size_t> inputShape) {
return [fun, inputShape](ngraph::element::Type type, std::size_t batchSize) {
auto shape = inputShape;
shape[0] = batchSize;
return fun(shape, type);
};
ngraphFunctionGenerator LoadNetworkCacheTestBase::inputShapeWrapper(ngraphFunctionIS fun, std::vector<size_t> inputShape) {
return [fun, inputShape](ngraph::element::Type type, std::size_t batchSize) {
auto shape = inputShape;
shape[0] = batchSize;
return fun(shape, type);
};
}
std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getNumericTypeOnlyFunctions() {
std::vector<nGraphFunctionWithName> res;
res.push_back(nGraphFunctionWithName { simple_function_multiply, "SimpleFunctionMultiply"});
res.push_back(nGraphFunctionWithName { simple_function_relu, "SimpleFunctionRelu"});
@ -80,9 +77,6 @@ std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getStandardFunctio
res.push_back(nGraphFunctionWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeKSOFunction, {1, 4, 20, 20}),
"KSOFunction"});
res.push_back(nGraphFunctionWithName { [](ngraph::element::Type type, size_t batchSize) {
return ngraph::builder::subgraph::makeTIwithLSTMcell(type, batchSize);
}, "TIwithLSTMcell1"});
res.push_back(nGraphFunctionWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeSingleConv, {1, 3, 24, 24}),
"SingleConv"});
@ -104,16 +98,44 @@ std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getStandardFunctio
res.push_back(nGraphFunctionWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeConvBias, {1, 3, 24, 24}),
"ConvBias"});
res.push_back(nGraphFunctionWithName {
inputShapeWrapper(ngraph::builder::subgraph::makeReadConcatSplitAssign, {1, 1, 2, 4}),
"ReadConcatSplitAssign"});
res.push_back(nGraphFunctionWithName{
inputShapeWrapper(ngraph::builder::subgraph::makeMatMulBias, {1, 3, 24, 24}),
"MatMulBias" });
return res;
}
std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getAnyTypeOnlyFunctions() {
std::vector<nGraphFunctionWithName> res;
return res;
}
std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getFloatingPointOnlyFunctions() {
std::vector<nGraphFunctionWithName> res;
res.push_back(nGraphFunctionWithName { [](ngraph::element::Type type, size_t batchSize) {
return ngraph::builder::subgraph::makeTIwithLSTMcell(type, batchSize);
}, "TIwithLSTMcell1"});
return res;
}
std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getNumericAnyTypeFunctions() {
std::vector<nGraphFunctionWithName> funcs = LoadNetworkCacheTestBase::getAnyTypeOnlyFunctions();
std::vector<nGraphFunctionWithName> numericType = LoadNetworkCacheTestBase::getNumericTypeOnlyFunctions();
funcs.insert(funcs.end(), numericType.begin(), numericType.end());
return funcs;
}
std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getStandardFunctions() {
std::vector<nGraphFunctionWithName> funcs = LoadNetworkCacheTestBase::getAnyTypeOnlyFunctions();
std::vector<nGraphFunctionWithName> numericType = LoadNetworkCacheTestBase::getNumericTypeOnlyFunctions();
funcs.insert(funcs.end(), numericType.begin(), numericType.end());
std::vector<nGraphFunctionWithName> floatType = LoadNetworkCacheTestBase::getFloatingPointOnlyFunctions();
funcs.insert(funcs.end(), floatType.begin(), floatType.end());
return funcs;
}
bool LoadNetworkCacheTestBase::importExportSupported(InferenceEngine::Core& ie) const {
std::vector<std::string> supportedMetricKeys = ie.GetMetric(targetDevice, METRIC_KEY(SUPPORTED_METRICS));
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),