conformance tests for BATCH, alos batch size 1 is default for BATCH:DEVICE

This commit is contained in:
myshevts 2021-12-01 15:11:39 +03:00
parent ddbeff3d46
commit f755fc6b69
9 changed files with 52 additions and 17 deletions

View File

@ -1,25 +1,22 @@
# Copyright (C) 2018-2020 Intel Corporation # Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
set (TARGET_NAME "AutoBatchPlugin") set(TARGET_NAME "AutoBatchPlugin")
if(ENABLE_LTO) file(GLOB SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
ie_enable_lto()
endif()
file(GLOB SOURCES file(GLOB HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/*.hpp)
${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
)
file(GLOB HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/*.hpp
)
ie_add_plugin(NAME ${TARGET_NAME} ie_add_plugin(NAME ${TARGET_NAME}
DEVICE_NAME "BATCH" DEVICE_NAME "BATCH"
SOURCES ${SOURCES} ${HEADERS} SOURCES ${SOURCES} ${HEADERS}
VERSION_DEFINES_FOR auto_batch.cpp) VERSION_DEFINES_FOR auto_batch.cpp)
target_link_libraries(${TARGET_NAME} PRIVATE inference_engine inference_engine_legacy) target_link_libraries(${TARGET_NAME} PRIVATE ngraph inference_engine_transformations inference_engine_legacy)
set_ie_threading_interface_for(${TARGET_NAME}) set_ie_threading_interface_for(${TARGET_NAME})
ie_add_api_validator_post_build_step(TARGET ${TARGET_NAME})
set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO})

View File

@ -447,7 +447,7 @@ DeviceInformation AutoBatchInferencePlugin::ParseMetaDevice(const std::string& d
auto closingBracket = d.find_first_of(')', openingBracket); auto closingBracket = d.find_first_of(')', openingBracket);
auto deviceName = d.substr(0, openingBracket); auto deviceName = d.substr(0, openingBracket);
int batch = -1; int batch = 1;
if (closingBracket != std::string::npos && openingBracket < closingBracket) { if (closingBracket != std::string::npos && openingBracket < closingBracket) {
batch = std::stol(d.substr(openingBracket + 1, closingBracket - 1)); batch = std::stol(d.substr(openingBracket + 1, closingBracket - 1));

View File

@ -24,6 +24,7 @@ inline const std::string getPluginLibNameByDevice(const std::string& deviceName)
{ "GNA", "GNAPlugin" }, { "GNA", "GNAPlugin" },
{ "GPU", "clDNNPlugin" }, { "GPU", "clDNNPlugin" },
{ "HETERO", "ov_hetero_plugin" }, { "HETERO", "ov_hetero_plugin" },
{ "BATCH", "AutoBatchPlugin" },
{ "MULTI", "MultiDevicePlugin" }, { "MULTI", "MultiDevicePlugin" },
{ "MYRIAD", "myriadPlugin" }, { "MYRIAD", "myriadPlugin" },
{ "TEMPLATE", "templatePlugin" }, { "TEMPLATE", "templatePlugin" },
@ -42,6 +43,11 @@ inline const std::pair<std::string, std::string> generateDefaultHeteroConfig() {
return { "TARGET_FALLBACK" , ConformanceTests::targetDevice }; return { "TARGET_FALLBACK" , ConformanceTests::targetDevice };
} }
inline const std::pair<std::string, std::string> generateDefaultBatchConfig() {
// auto-batching with batch 1 (no real batching in fact, but full machinery is in action)
return { CONFIG_KEY(AUTO_BATCH) , std::string(ConformanceTests::targetDevice)};
}
inline const std::vector<std::map<std::string, std::string>> generateConfigs(const std::string& targetDevice, inline const std::vector<std::map<std::string, std::string>> generateConfigs(const std::string& targetDevice,
const std::vector<std::map<std::string, std::string>>& config = {}) { const std::vector<std::map<std::string, std::string>>& config = {}) {
std::pair<std::string, std::string> defaultConfig; std::pair<std::string, std::string> defaultConfig;
@ -49,6 +55,8 @@ inline const std::vector<std::map<std::string, std::string>> generateConfigs(con
defaultConfig = generateDefaultMultiConfig(); defaultConfig = generateDefaultMultiConfig();
} else if (targetDevice == std::string(CommonTestUtils::DEVICE_HETERO)) { } else if (targetDevice == std::string(CommonTestUtils::DEVICE_HETERO)) {
defaultConfig = generateDefaultHeteroConfig(); defaultConfig = generateDefaultHeteroConfig();
} else if (targetDevice == std::string(CommonTestUtils::DEVICE_BATCH)) {
defaultConfig = generateDefaultBatchConfig();
} else { } else {
throw std::runtime_error("Incorrect target device: " + targetDevice); throw std::runtime_error("Incorrect target device: " + targetDevice);
} }
@ -70,7 +78,8 @@ inline const std::string generateComplexDeviceName(const std::string& deviceName
inline const std::vector<std::string> returnAllPossibleDeviceCombination() { inline const std::vector<std::string> returnAllPossibleDeviceCombination() {
std::vector<std::string> res{ConformanceTests::targetDevice}; std::vector<std::string> res{ConformanceTests::targetDevice};
std::vector<std::string> devices{CommonTestUtils::DEVICE_HETERO, CommonTestUtils::DEVICE_AUTO, CommonTestUtils::DEVICE_MULTI}; std::vector<std::string> devices{CommonTestUtils::DEVICE_HETERO, CommonTestUtils::DEVICE_AUTO,
CommonTestUtils::DEVICE_BATCH, CommonTestUtils::DEVICE_MULTI};
for (const auto& device : devices) { for (const auto& device : devices) {
res.emplace_back(generateComplexDeviceName(device)); res.emplace_back(generateComplexDeviceName(device));
} }

View File

@ -33,4 +33,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestCallbackTests,
::testing::Values(CommonTestUtils::DEVICE_HETERO), ::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))), ::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))),
InferRequestCallbackTests::getTestCaseName); InferRequestCallbackTests::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Batch_BehaviorTests, InferRequestCallbackTests,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_BATCH),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_BATCH))),
InferRequestCallbackTests::getTestCaseName);
} // namespace } // namespace

View File

@ -36,4 +36,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestIOBBlobTest,
::testing::Values(CommonTestUtils::DEVICE_HETERO), ::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))), ::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))),
InferRequestIOBBlobTest::getTestCaseName); InferRequestIOBBlobTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Batch_BehaviorTests, InferRequestIOBBlobTest,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_BATCH),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_BATCH))),
InferRequestIOBBlobTest::getTestCaseName);
} // namespace } // namespace

View File

@ -38,4 +38,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestMultithreadingT
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))), ::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))),
InferRequestMultithreadingTests::getTestCaseName); InferRequestMultithreadingTests::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Batch_BehaviorTests, InferRequestMultithreadingTests,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_BATCH),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_BATCH))),
InferRequestMultithreadingTests::getTestCaseName);
} // namespace } // namespace

View File

@ -46,4 +46,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_Behavior_Hetero, InferRequestSetBlobByType,
::testing::Values(CommonTestUtils::DEVICE_HETERO), ::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))), ::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))),
InferRequestSetBlobByType::getTestCaseName); InferRequestSetBlobByType::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Behavior_Batch, InferRequestSetBlobByType,
::testing::Combine(::testing::ValuesIn(setBlobTypes),
::testing::Values(CommonTestUtils::DEVICE_BATCH),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_BATCH))),
InferRequestSetBlobByType::getTestCaseName);
} // namespace } // namespace

View File

@ -37,4 +37,9 @@ INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestWaitTests,
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))), ::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_HETERO))),
InferRequestWaitTests::getTestCaseName); InferRequestWaitTests::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Batch_BehaviorTests, InferRequestWaitTests,
::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_BATCH),
::testing::ValuesIn(generateConfigs(CommonTestUtils::DEVICE_BATCH))),
InferRequestWaitTests::getTestCaseName);
} // namespace } // namespace

View File

@ -32,8 +32,8 @@ public:
fn_ptr = ngraph::builder::subgraph::makeSplitMultiConvConcat(); fn_ptr = ngraph::builder::subgraph::makeSplitMultiConvConcat();
deviceName = CommonTestUtils::DEVICE_GPU; deviceName = CommonTestUtils::DEVICE_GPU;
auto with_auto_batching = this->GetParam(); auto with_auto_batching = this->GetParam();
if (with_auto_batching) { // BATCH:GPU(1) if (with_auto_batching) { // BATCH:GPU
deviceName = std::string(CommonTestUtils::DEVICE_BATCH) + ":" + deviceName + "(1)"; deviceName = std::string(CommonTestUtils::DEVICE_BATCH) + ":" + deviceName;
config = {{CONFIG_KEY(ALLOW_AUTO_BATCHING), CONFIG_VALUE(YES)}}; config = {{CONFIG_KEY(ALLOW_AUTO_BATCHING), CONFIG_VALUE(YES)}};
} }
} }