Small refactoring in TEMPLATE plugin (#5398)

* Small refactoring in TEMPLATE plugin

* Fixed compilation on Windows

* Fixed code style
Ilya Lavrenov 2021-04-27 18:52:45 +03:00 committed by GitHub
parent 689f8aedb6
commit 6c83e0f8a5
7 changed files with 22 additions and 70 deletions


@@ -14,6 +14,7 @@
namespace TemplatePlugin {
// forward declaration
class Plugin;
/**


@@ -8,17 +8,6 @@
#include <string>
#include <map>
#include <ie_blob.h>
#include <description_buffer.hpp>
#include <debug.h>
#include <ie_layouts.h>
#include <threading/ie_executor_manager.hpp>
#include <blob_transform.hpp>
#include <ie_parallel.hpp>
#include <ie_memcpy.h>
#include <precision_utils.h>
#include "template/template_config.hpp"
#include "template_infer_request.hpp"
#include "template_executable_network.hpp"
#include "template_plugin.hpp"


@@ -7,22 +7,21 @@
#include <map>
#include <string>
#include <vector>
#include <array>
#include <memory>
#include <unordered_map>
#include <chrono>
#include <ie_common.h>
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <threading/ie_itask_executor.hpp>
#include <openvino/itt.hpp>
#include <ie_input_info.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <ngraph/runtime/tensor.hpp>
#include <executable.hpp>
#include "template_config.hpp"
namespace TemplatePlugin {
// forward declaration
class ExecutableNetwork;
// ! [infer_request:header]


@@ -81,50 +81,19 @@ InferenceEngine::ExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(const
const ConfigMap &config) {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::LoadExeNetworkImpl");
auto cfg = Configuration{ config, _cfg };
InferenceEngine::InputsDataMap networkInputs = network.getInputsInfo();
InferenceEngine::OutputsDataMap networkOutputs = network.getOutputsInfo();
// TODO: check with precisions supported by Template device
for (auto networkOutput : networkOutputs) {
auto output_precision = networkOutput.second->getPrecision();
if (output_precision != InferenceEngine::Precision::FP32 &&
output_precision != InferenceEngine::Precision::FP16 &&
output_precision != InferenceEngine::Precision::U8) {
IE_THROW() << "Template device supports only U8, FP16 and FP32 output precision.";
}
}
for (auto networkInput : networkInputs) {
auto input_precision = networkInput.second->getTensorDesc().getPrecision();
if (input_precision != InferenceEngine::Precision::FP32 &&
input_precision != InferenceEngine::Precision::FP16 &&
input_precision != InferenceEngine::Precision::I16 &&
input_precision != InferenceEngine::Precision::U8) {
IE_THROW() << "Input image format " << input_precision << " is not supported yet.\n"
<< "Supported formats are: FP32, FP16, I16 and U8.";
}
}
auto function = network.getFunction();
if (function == nullptr) {
IE_THROW() << "TEMPLATE plugin can compile only IR v10 networks";
}
return std::make_shared<ExecutableNetwork>(function, cfg, std::static_pointer_cast<Plugin>(shared_from_this()));
auto fullConfig = Configuration{ config, _cfg };
return std::make_shared<ExecutableNetwork>(network.getFunction(), fullConfig,
std::static_pointer_cast<Plugin>(shared_from_this()));
}
// ! [plugin:load_exe_network_impl]
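For readability, here is a minimal sketch of how the refactored LoadExeNetworkImpl reads once this hunk is applied; the CNNNetwork parameter type is inferred from the getInputsInfo()/getFunction() calls in the removed lines, and the body is taken from the added lines above:

InferenceEngine::ExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network,
                                                                           const ConfigMap& config) {
    OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::LoadExeNetworkImpl");

    // The per-precision input/output checks and the explicit IR v10 check are gone;
    // the per-call config is merged with the plugin defaults and forwarded.
    auto fullConfig = Configuration{ config, _cfg };
    return std::make_shared<ExecutableNetwork>(network.getFunction(), fullConfig,
                                               std::static_pointer_cast<Plugin>(shared_from_this()));
}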
// ! [plugin:import_network_impl]
InferenceEngine::ExecutableNetworkInternal::Ptr
Plugin::ImportNetworkImpl(std::istream& model, const std::map<std::string, std::string>& config) {
Plugin::ImportNetworkImpl(std::istream& modelStream, const std::map<std::string, std::string>& config) {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::ImportNetworkImpl");
Configuration cfg(config);
return std::make_shared<ExecutableNetwork>(model, cfg,
auto fullConfig = Configuration{ config, _cfg };
return std::make_shared<ExecutableNetwork>(modelStream, fullConfig,
std::static_pointer_cast<Plugin>(shared_from_this()));
}
// ! [plugin:import_network_impl]
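ImportNetworkImpl follows the same pattern: the stream parameter is renamed to modelStream, and the bare Configuration cfg(config) is replaced by a configuration merged with the plugin defaults stored in _cfg. A sketch assembled from the lines above:

InferenceEngine::ExecutableNetworkInternal::Ptr
Plugin::ImportNetworkImpl(std::istream& modelStream, const std::map<std::string, std::string>& config) {
    OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::ImportNetworkImpl");

    // Merge the per-call config with the plugin-wide defaults, mirroring LoadExeNetworkImpl.
    auto fullConfig = Configuration{ config, _cfg };
    return std::make_shared<ExecutableNetwork>(modelStream, fullConfig,
                                               std::static_pointer_cast<Plugin>(shared_from_this()));
}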
@@ -133,13 +102,8 @@ Plugin::ImportNetworkImpl(std::istream& model, const std::map<std::string, std::
InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::CNNNetwork &network, const ConfigMap& config) const {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::QueryNetwork");
InferenceEngine::QueryNetworkResult res;
Configuration cfg{config, _cfg, false};
Configuration fullConfig{config, _cfg, false};
auto function = network.getFunction();
if (function == nullptr) {
IE_THROW() << "Template Plugin supports only ngraph cnn network representation";
}
// 1. First of all we should store initial input operation set
std::unordered_set<std::string> originalOps;
@@ -207,6 +171,7 @@ InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::
}
// 7. Produce the result
InferenceEngine::QueryNetworkResult res;
for (auto&& layerName : supported) {
res.supportedLayersMap.emplace(layerName, GetName());
}
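In QueryNetwork the local configuration is renamed to fullConfig, the null-function check is dropped, and QueryNetworkResult is now declared at step 7, where it is actually filled. A condensed sketch of the method after the change (steps 2-6 and the final return lie outside these hunks and are only assumed here):

InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::CNNNetwork &network, const ConfigMap& config) const {
    OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::QueryNetwork");

    Configuration fullConfig{config, _cfg, false};
    auto function = network.getFunction();

    // 1. First of all we should store initial input operation set
    std::unordered_set<std::string> originalOps;
    // ... steps 2-6 are untouched by this commit ...

    // 7. Produce the result
    InferenceEngine::QueryNetworkResult res;
    for (auto&& layerName : supported) {
        res.supportedLayersMap.emplace(layerName, GetName());
    }
    return res;  // assumed: the tail of the method is not part of the diff
}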


@@ -4,11 +4,9 @@
#include "test_utils_api_impl.hpp"
#include <common_test_utils/ngraph_test_utils.hpp>
#include <string>
#include <common_test_utils/ngraph_test_utils.hpp>
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs,
InferenceEnginePython::IENetwork rhs) {
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs, InferenceEnginePython::IENetwork rhs) {
return compare_functions(lhs.actual->getFunction(), rhs.actual->getFunction(), true, true, false, true);
}


@@ -286,6 +286,11 @@ struct QueryNetworkResult {
*/
using ConstOutputsDataMap = std::map<std::string, CDataPtr>;
/**
* @brief A collection that contains string as key, and Data smart pointer as value
*/
using OutputsDataMap = std::map<std::string, DataPtr>;
namespace details {
struct INFERENCE_ENGINE_DEPRECATED("Use InferRequest::Exception")
INFERENCE_ENGINE_API_CLASS(InferenceEngineException) : public std::runtime_error {
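The last two hunks move the OutputsDataMap alias out of the ICNNNetwork header (see the hunk below) and into the common header next to its const counterpart; file names are not shown in this view, so the placement is inferred from the surrounding declarations. After the move the two aliases sit together and read roughly as follows (the ConstOutputsDataMap comment is paraphrased):

/**
 * @brief A collection that contains string as key, and const Data smart pointer as value
 */
using ConstOutputsDataMap = std::map<std::string, CDataPtr>;

/**
 * @brief A collection that contains string as key, and Data smart pointer as value
 */
using OutputsDataMap = std::map<std::string, DataPtr>;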


@@ -34,11 +34,6 @@ class Function;
namespace InferenceEngine {
/**
* @brief A collection that contains string as key, and Data smart pointer as value
*/
using OutputsDataMap = std::map<std::string, DataPtr>;
/**
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead
* @interface ICNNNetwork