[GPU] Fix batchability check of MAX_BATCH_SIZE (#10660)
* Fix batchability check of MAX_BATCH_SIZE
* Applied review comment
This commit is contained in:
parent
2dbd60c1ae
commit
13b6a3d86e
@ -26,9 +26,15 @@
|
||||
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
|
||||
#include "ie_icore.hpp"
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/common_optimizations/dimension_tracking.hpp"
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include "openvino/pass/serialize.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <openvino/util/common_util.hpp>
|
||||
|
||||
#include "intel_gpu/runtime/device_query.hpp"
|
||||
@ -968,43 +974,61 @@ Parameter Plugin::GetMetric(const std::string& name, const std::map<std::string,
|
||||
auto cloned_network = InferenceEngine::details::cloneNetwork(network);
|
||||
auto inputs_info = cloned_network.getInputsInfo();
|
||||
ICNNNetwork::InputShapes new_shapes;
|
||||
//std::map<std::string, SizeVector>;
|
||||
bool batch_detected = false;
|
||||
for (auto& info : inputs_info) {
|
||||
if (!info.second)
|
||||
continue;
|
||||
InferenceEngine::Layout layout = info.second->getLayout();
|
||||
auto data = info.second->getInputData();
|
||||
if (!data)
|
||||
continue;
|
||||
std::string name = info.second->getInputData()->getName();
|
||||
auto shape = data->getTensorDesc().getDims();
|
||||
if (layout == InferenceEngine::Layout::NCHW ||
|
||||
layout == InferenceEngine::Layout::NHWC ||
|
||||
layout == InferenceEngine::Layout::NCDHW ||
|
||||
layout == InferenceEngine::Layout::NDHWC ||
|
||||
layout == InferenceEngine::Layout::NC) {
|
||||
shape[0] = base_batch_size;
|
||||
batch_detected = true;
|
||||
} else if (layout == InferenceEngine::Layout::CN) {
|
||||
shape[1] = base_batch_size;
|
||||
batch_detected = true;
|
||||
}
|
||||
new_shapes[name] = shape;
|
||||
}
|
||||
|
||||
try {
|
||||
if (batch_detected) { // reshape only for batched layout
|
||||
cloned_network.reshape(new_shapes);
|
||||
GPU_DEBUG_IF(debug_config->verbose >= 1) {
|
||||
GPU_DEBUG_COUT << "[GPU_MAX_BATCH_SIZE] Reshaped base batch size to " << base_batch_size << std::endl;
|
||||
std::set<std::pair<std::string, size_t>> batched_inputs;
|
||||
|
||||
auto function = InferenceEngine::details::cloneNetwork(cloned_network).getFunction();
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ov::pass::FindBatch>(true, false);
|
||||
m.run_passes(function);
|
||||
const auto& params = function->get_parameters();
|
||||
for (size_t input_id = 0; input_id < params.size(); input_id++) {
|
||||
const auto& input = params[input_id];
|
||||
const auto& shape = input->get_partial_shape();
|
||||
// currently no plugin supports batched execution for dynamic networks
|
||||
if (shape.is_dynamic()) {
|
||||
GPU_DEBUG_IF(debug_config->verbose >= 2) {
|
||||
GPU_DEBUG_COUT << "[MAX_BATCH_SIZE] does not support dynamic networks" << std::endl;
|
||||
}
|
||||
return decltype(ov::max_batch_size)::value_type {static_cast<uint32_t>(max_batch_size)};
|
||||
}
|
||||
} else {
|
||||
base_batch_size = 1;
|
||||
GPU_DEBUG_IF(debug_config->verbose >= 1) {
|
||||
GPU_DEBUG_COUT << "[GPU_MAX_BATCH_SIZE] Batch dimension is not used in inputs." << std::endl;
|
||||
|
||||
if (shape.size()) {
|
||||
for (size_t s = 0; s < shape.size(); s++) {
|
||||
if (ov::DimensionTracker::get_label(shape[s])) {
|
||||
// batched dim for the input
|
||||
auto batched_input_id = ngraph::op::util::get_ie_output_name(params[input_id]->output(0));
|
||||
GPU_DEBUG_IF(debug_config->verbose >= 2) {
|
||||
GPU_DEBUG_COUT << "[MAX_BATCH_SIZE] detected batched input " << batched_input_id
|
||||
<< "[" << s << "]" << std::endl;
|
||||
}
|
||||
batched_inputs.insert(std::make_pair(batched_input_id, s));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!batched_inputs.size()) {
|
||||
GPU_DEBUG_IF(debug_config->verbose >= 2) {
|
||||
GPU_DEBUG_COUT << "[MAX_BATCH_SIZE] MAX_BATCH_SIZE supports only networks with inputs/outputs featuring batched dim." << std::endl;
|
||||
}
|
||||
return decltype(ov::max_batch_size)::value_type {static_cast<uint32_t>(max_batch_size)};
|
||||
}
|
||||
|
||||
try {
|
||||
ICNNNetwork::InputShapes shapes = cloned_network.getInputShapes();
|
||||
for (const auto& input : batched_inputs)
|
||||
shapes[input.first][input.second] = base_batch_size;
|
||||
cloned_network.reshape(shapes);
|
||||
} catch (...) {
|
||||
GPU_DEBUG_IF(debug_config->verbose >= 1) {
|
||||
GPU_DEBUG_COUT << "[MAX_BATCH_SIZE] Error at reshape to " << base_batch_size << std::endl;
|
||||
}
|
||||
return decltype(ov::max_batch_size)::value_type {static_cast<uint32_t>(max_batch_size)};
|
||||
}
|
||||
|
||||
auto nGraphFunc = cloned_network.getFunction();
|
||||
TransformationsPipeline transformations(config, device_info);
|
||||
transformations.apply(nGraphFunc);
|
||||
|
Loading…
Reference in New Issue
Block a user