Add dynamism in memory tests (API 2) (#10589)

This commit is contained in:
Valentin Dymchishin 2022-03-28 12:51:53 +03:00 committed by GitHub
parent 76e2f2697f
commit 52937967bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 89 additions and 36 deletions

View File

@ -29,3 +29,4 @@ add_subdirectory(src)
install(DIRECTORY test_runner/ DESTINATION tests/memory_tests/test_runner COMPONENT tests EXCLUDE_FROM_ALL) install(DIRECTORY test_runner/ DESTINATION tests/memory_tests/test_runner COMPONENT tests EXCLUDE_FROM_ALL)
install(DIRECTORY .automation/ DESTINATION tests/memory_tests/test_runner/.automation COMPONENT tests EXCLUDE_FROM_ALL) install(DIRECTORY .automation/ DESTINATION tests/memory_tests/test_runner/.automation COMPONENT tests EXCLUDE_FROM_ALL)
install(DIRECTORY scripts/ DESTINATION tests/memory_tests/scripts COMPONENT tests EXCLUDE_FROM_ALL) install(DIRECTORY scripts/ DESTINATION tests/memory_tests/scripts COMPONENT tests EXCLUDE_FROM_ALL)
install(DIRECTORY ../utils/ DESTINATION tests/utils COMPONENT tests EXCLUDE_FROM_ALL)

View File

@ -22,6 +22,6 @@ public:
MemoryCounter(const std::string &mem_counter_name); MemoryCounter(const std::string &mem_counter_name);
}; };
#define MEMORY_SNAPSHOT(mem_counter_name) MemoryTest::MemoryCounter (#mem_counter_name); #define MEMORY_SNAPSHOT(mem_counter_name) MemoryTest::MemoryCounter mem_counter_name(#mem_counter_name);
} // namespace MemoryTest } // namespace MemoryTest

View File

@ -15,7 +15,7 @@ foreach(test_source ${tests})
get_filename_component(test_name ${test_source} NAME_WE) get_filename_component(test_name ${test_source} NAME_WE)
add_executable(${test_name} ${test_source}) add_executable(${test_name} ${test_source})
target_link_libraries(${test_name} PRIVATE memory_tests_helper tests_shared_lib) target_link_libraries(${test_name} PRIVATE tests_shared_lib memory_tests_helper)
add_dependencies(memory_tests ${test_name}) add_dependencies(memory_tests ${test_name})

View File

@ -15,8 +15,12 @@
* main(). The function should not throw any exceptions and responsible for * main(). The function should not throw any exceptions and responsible for
* handling it by itself. * handling it by itself.
*/ */
int runPipeline(const std::string &model, const std::string &device) { int runPipeline(const std::string &model, const std::string &device,
auto pipeline = [](const std::string &model, const std::string &device) { std::map<std::string, ov::PartialShape> reshapeShapes,
std::map<std::string, std::vector<size_t>> dataShapes) {
auto pipeline = [](const std::string &model, const std::string &device,
std::map<std::string, ov::PartialShape> reshapeShapes,
std::map<std::string, std::vector<size_t>> dataShapes) {
InferenceEngine::Core ie; InferenceEngine::Core ie;
InferenceEngine::CNNNetwork cnnNetwork; InferenceEngine::CNNNetwork cnnNetwork;
InferenceEngine::ExecutableNetwork exeNetwork; InferenceEngine::ExecutableNetwork exeNetwork;
@ -53,7 +57,7 @@ int runPipeline(const std::string &model, const std::string &device) {
}; };
try { try {
pipeline(model, device); pipeline(model, device, reshapeShapes, dataShapes);
} catch (const InferenceEngine::Exception &iex) { } catch (const InferenceEngine::Exception &iex) {
std::cerr std::cerr
<< "Inference Engine pipeline failed with Inference Engine exception:\n" << "Inference Engine pipeline failed with Inference Engine exception:\n"

View File

@ -1,15 +1,15 @@
// Copyright (C) 2018-2022 Intel Corporation // Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include <openvino/runtime/core.hpp>
#include <openvino/runtime/infer_request.hpp>
#include <iostream>
#include <fstream> #include <fstream>
#include "common_utils.h" #include "common_utils.h"
#include "reshape_utils.h"
#include "memory_tests_helper/memory_counter.h" #include "memory_tests_helper/memory_counter.h"
#include "memory_tests_helper/utils.h" #include "memory_tests_helper/utils.h"
#include "openvino/runtime/core.hpp"
/** /**
@ -17,43 +17,66 @@
* main(). The function should not throw any exceptions and responsible for * main(). The function should not throw any exceptions and responsible for
* handling it by itself. * handling it by itself.
*/ */
int runPipeline(const std::string &model, const std::string &device) { int runPipeline(const std::string &model, const std::string &device,
auto pipeline = [](const std::string &model, const std::string &device) { std::map<std::string, ov::PartialShape> reshapeShapes,
std::map<std::string, std::vector<size_t>> dataShapes) {
auto pipeline = [](const std::string &model, const std::string &device,
std::map<std::string, ov::PartialShape> reshapeShapes,
std::map<std::string, std::vector<size_t>> dataShapes) {
ov::Core ie; ov::Core ie;
std::shared_ptr<ov::Model> network; std::shared_ptr<ov::Model> cnnNetwork;
ov::CompiledModel compiled_model; ov::CompiledModel exeNetwork;
ov::InferRequest infer_request; ov::InferRequest inferRequest;
std::vector<ov::Output<ov::Node>> defaultInputs;
bool reshape = false;
if (!reshapeShapes.empty()) {
reshape = true;
}
ie.get_versions(device); ie.get_versions(device);
MEMORY_SNAPSHOT(load_plugin); MEMORY_SNAPSHOT(load_plugin);
if (MemoryTest::fileExt(model) == "blob") { if (MemoryTest::fileExt(model) == "blob") {
std::ifstream streamModel{model}; std::ifstream streamModel{model};
compiled_model = ie.import_model(streamModel, device); exeNetwork = ie.import_model(streamModel, device);
MEMORY_SNAPSHOT(import_network); MEMORY_SNAPSHOT(import_network);
} else { } else {
network = ie.read_model(model); cnnNetwork = ie.read_model(model);
MEMORY_SNAPSHOT(read_network); MEMORY_SNAPSHOT(read_network);
compiled_model = ie.compile_model(network, device); if (reshape) {
defaultInputs = getCopyOfDefaultInputs(cnnNetwork->inputs());
cnnNetwork->reshape(reshapeShapes);
MEMORY_SNAPSHOT(reshape);
}
exeNetwork = ie.compile_model(cnnNetwork, device);
MEMORY_SNAPSHOT(load_network); MEMORY_SNAPSHOT(load_network);
} }
MEMORY_SNAPSHOT(create_exenetwork); MEMORY_SNAPSHOT(create_exenetwork);
infer_request = compiled_model.create_infer_request(); inferRequest = exeNetwork.create_infer_request();
auto inputs = network->inputs(); std::vector<ov::Output<const ov::Node>> inputs = exeNetwork.inputs();
fillTensors(infer_request, inputs); if (reshape && dataShapes.empty()) {
MEMORY_SNAPSHOT(fill_inputs) fillTensors(inferRequest, defaultInputs);
} else if (reshape && !dataShapes.empty()) {
fillTensorsWithSpecifiedShape(inferRequest, inputs, dataShapes);
} else {
fillTensors(inferRequest, inputs);
}
MEMORY_SNAPSHOT(fill_inputs);
infer_request.infer(); inferRequest.infer();
MEMORY_SNAPSHOT(first_inference); MEMORY_SNAPSHOT(first_inference);
MEMORY_SNAPSHOT(full_run); MEMORY_SNAPSHOT(full_run);
}; };
try { try {
pipeline(model, device); pipeline(model, device, reshapeShapes, dataShapes);
} catch (const InferenceEngine::Exception &iex) { } catch (const InferenceEngine::Exception &iex) {
std::cerr std::cerr
<< "Inference Engine pipeline failed with Inference Engine exception:\n" << "Inference Engine pipeline failed with Inference Engine exception:\n"

View File

@ -12,4 +12,4 @@ add_subdirectory(${OpenVINO_SOURCE_DIR}/thirdparty/gflags
${CMAKE_CURRENT_BINARY_DIR}/gflags_build ${CMAKE_CURRENT_BINARY_DIR}/gflags_build
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL)
target_link_libraries(${TARGET_NAME} PUBLIC gflags) target_link_libraries(${TARGET_NAME} PUBLIC gflags tests_shared_lib)

View File

@ -26,6 +26,16 @@ static const char target_device_message[] =
"plugin. " "plugin. "
"The application looks for a suitable plugin for the specified device."; "The application looks for a suitable plugin for the specified device.";
/// @brief message for shapes argument
static const char reshape_shapes_message[] =
"Not required. Use this key to run memory tests with reshape. \n"
"Example: 'input*1..2 3 100 100'. Use '&' delimiter for several inputs. Example: 'input1*1..2 100&input2*1..2 100' ";
/// @brief message for shapes argument
static const char data_shapes_message[] =
"Not required. Use this key to run memory tests with reshape. Used with 'reshape_shapes' arg. \n"
"Only static shapes for data. Example: 'input*1 3 100 100'. Use '&' delimiter for several inputs. Example: 'input1*1 100&input2*1 100' ";
/// @brief message for statistics path argument /// @brief message for statistics path argument
static const char statistics_path_message[] = static const char statistics_path_message[] =
"Required. Path to a file to write statistics."; "Required. Path to a file to write statistics.";
@ -44,6 +54,14 @@ DEFINE_string(m, "", model_message);
/// It is a required parameter /// It is a required parameter
DEFINE_string(d, "", target_device_message); DEFINE_string(d, "", target_device_message);
/// @brief Define parameter for set shapes to reshape function <br>
/// It is a non-required parameter
DEFINE_string(reshape_shapes, "", reshape_shapes_message);
/// @brief Define parameter for set shapes of the network data <br>
/// It is a non-required parameter
DEFINE_string(data_shapes, "", data_shapes_message);
/// @brief Define parameter for set path to a file to write statistics <br> /// @brief Define parameter for set path to a file to write statistics <br>
/// It is a required parameter /// It is a required parameter
DEFINE_string(s, "", statistics_path_message); DEFINE_string(s, "", statistics_path_message);
@ -53,13 +71,13 @@ DEFINE_string(s, "", statistics_path_message);
*/ */
static void showUsage() { static void showUsage() {
std::cout << std::endl; std::cout << std::endl;
std::cout << "TimeTests [OPTION]" << std::endl; std::cout << "MemoryInfer [OPTION]" << std::endl;
std::cout << "Options:" << std::endl; std::cout << "Options:" << std::endl;
std::cout << std::endl; std::cout << std::endl;
std::cout << " -h, --help " << help_message << std::endl; std::cout << " -h, --help " << help_message << std::endl;
std::cout << " -m \"<path>\" " << model_message << std::endl; std::cout << " -m \"<path>\" " << model_message << std::endl;
std::cout << " -d \"<device>\" " << target_device_message std::cout << " -d \"<device>\" " << target_device_message << std::endl;
<< std::endl; std::cout << " -s \"<path>\" " << statistics_path_message << std::endl;
std::cout << " -s \"<path>\" " << statistics_path_message std::cout << " -reshape_shapes " << reshape_shapes_message << std::endl;
<< std::endl; std::cout << " -data_shapes " << data_shapes_message << std::endl;
} }

View File

@ -4,11 +4,14 @@
#include "cli.h" #include "cli.h"
#include "statistics_writer.h" #include "statistics_writer.h"
#include "reshape_utils.h"
#include "memory_tests_helper/memory_counter.h" #include "memory_tests_helper/memory_counter.h"
#include <iostream> #include <iostream>
int runPipeline(const std::string &model, const std::string &device); int runPipeline(const std::string &model, const std::string &device,
std::map<std::string, ov::PartialShape> reshapeShapes,
std::map<std::string, std::vector<size_t>> dataShapes);
/** /**
* @brief Parses command line and check required arguments * @brief Parses command line and check required arguments
@ -38,10 +41,11 @@ bool parseAndCheckCommandLine(int argc, char **argv) {
/** /**
* @brief Function calls `runPipeline` with mandatory memory values tracking of full run * @brief Function calls `runPipeline` with mandatory memory values tracking of full run
*/ */
int _runPipeline() { int _runPipeline(std::map<std::string, ov::PartialShape> dynamicShapes,
auto status = runPipeline(FLAGS_m, FLAGS_d); std::map<std::string, std::vector<size_t>> staticShapes) {
MEMORY_SNAPSHOT(after_objects_release); auto status = runPipeline(FLAGS_m, FLAGS_d, dynamicShapes, staticShapes);
return status; MEMORY_SNAPSHOT(after_objects_release);
return status;
} }
/** /**
@ -51,7 +55,10 @@ int main(int argc, char **argv) {
if (!parseAndCheckCommandLine(argc, argv)) if (!parseAndCheckCommandLine(argc, argv))
return -1; return -1;
auto status = _runPipeline(); auto dynamicShapes = parseReshapeShapes(FLAGS_reshape_shapes);
auto staticShapes = parseDataShapes(FLAGS_data_shapes);
auto status = _runPipeline(dynamicShapes, staticShapes);
StatisticsWriter::Instance().setFile(FLAGS_s); StatisticsWriter::Instance().setFile(FLAGS_s);
StatisticsWriter::Instance().write(); StatisticsWriter::Instance().write();
return status; return status;