checking the network batch-ability (internal helper func on top of bat… (#10446)
* checking the network batch-ability (internal helper func on top of batch tracking) before doing hetero * more general logic with respect to batch-ability of the network * a dynamism check that I've owed from the PR-10560 * using the DO-detached mechanism for early hetero exit, also fixed this flag in the Batching plugin (although minor, as the DO is removed by HETERO) * adding the dimension tracking logic depending on whether implicitly/explicitly the auto-batching is enabled * changed the DetectionOutput affinity markup to go over results, also accommodate Convert, so only 2 subgraphs are made by the HETERO
This commit is contained in:
79
src/inference/src/check_network_batchable.cpp
Normal file
79
src/inference/src/check_network_batchable.cpp
Normal file
@@ -0,0 +1,79 @@
|
||||
#include "check_network_batchable.hpp"
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "ie_ngraph_utils.hpp"
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
#include "openvino/op/detection_output.hpp"
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/common_optimizations/dimension_tracking.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
|
||||
namespace InferenceEngine {
namespace details {

// Decides whether a network can be executed via the automatic-batching plugin.
// Runs the batch-dimension tracking pass (FindBatch) over a CLONE of the network (so the
// original graph is not touched by the pass), validates the inputs, and then marks node
// affinities on the ORIGINAL network so that a HETERO fallback (DetectionOutput executed
// un-batched on the plain device) is possible.
// Returns:
//   NO          - batching is not applicable (dynamic shapes, already batched, no batch dim, ...)
//   AS_IS       - the whole network can be batched directly
//   WITH_HETERO - batchable, but DetectionOutput must be split out via HETERO
NetworkBatchAbility isNetworkBatchable(const CNNNetwork& orig_network,
                                       const std::string& deviceNameWithoutBatch,
                                       bool strictly_track_dims) {
    // clone, as the FindBatch pass mutates the function it runs on (dimension labels, etc.)
    CNNNetwork clonedNetwork(cloneNetwork(orig_network));
    auto function = clonedNetwork.getFunction();
    // find the batch dim
    ov::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    m.register_pass<ov::pass::FindBatch>(true, strictly_track_dims);
    m.run_passes(function);
    bool any_batched_inputs = false;
    // do not reshape/re-batch originally batched networks and when there are no inputs with the N* layouts
    // input(s) should have the batch dim as the first dim or none (current limitation of the auto-batching impl)
    const auto& params = function->get_parameters();
    for (size_t input_id = 0; input_id < params.size(); input_id++) {
        const auto& input = params[input_id];
        const auto& shape = input->get_partial_shape();
        // currently no plugin support batched execution for dynamic networks
        if (shape.is_dynamic())
            return NetworkBatchAbility::NO;
        // check the batch dim: either 0th (and the original batch size of 1) or none
        if (shape.size() && ov::DimensionTracker::get_label(shape[0])) {
            // 0th dim is labeled as batch: only re-batch networks with original batch of 1
            const auto& static_shape = input->get_shape();
            if (static_shape[0] != 1)
                return NetworkBatchAbility::NO;
            else
                any_batched_inputs = true;
        } else {
            // if the 0-th dim is not for the batch, then we support only the case when NONE dimension is batch
            for (size_t s = 1; s < shape.size(); s++)
                if (ov::DimensionTracker::get_label(shape[s]))
                    return NetworkBatchAbility::NO;
        }
    }
    // no input carries a (trackable) batch dimension => nothing to batch over
    if (!any_batched_inputs)
        return NetworkBatchAbility::NO;

    // affinity markup is done on the ORIGINAL network (the one that will actually be compiled),
    // not on the clone used above for dimension tracking
    for (auto&& node : orig_network.getFunction()->get_ops())
        node->get_rt_info()["affinity"] = "BATCH";  // default affinity (ignored if HETERO is not triggered)
    // have to execute the DetectionOutput separately (without batching)
    // as this layer does mix-in the values from the different inputs (batch id)
    bool bDetectionOutput = false;
    for (auto& result_node : orig_network.getFunction()->get_results()) {
        auto do_node = result_node->input_value(0).get_node_shared_ptr();
        std::shared_ptr<ov::Node> convert_node;
        if (ov::is_type<ov::opset1::Convert>(do_node)) {  // cases with do->convert->result
            convert_node = do_node;
            do_node = convert_node->get_input_node_shared_ptr(0);
        }
        // the code below doesn't need to separate the versions (opsets) of the DetectionOutput
        // so base class check is enough
        auto detectionOutputBase = std::dynamic_pointer_cast<ov::op::util::DetectionOutputBase>(do_node);
        if (detectionOutputBase) {
            // route the DetectionOutput (and its Result, plus the optional Convert in between)
            // to the plain device so only 2 subgraphs are produced by HETERO
            result_node->get_rt_info()["affinity"] = deviceNameWithoutBatch;
            do_node->get_rt_info()["affinity"] = deviceNameWithoutBatch;
            if (convert_node)
                convert_node->get_rt_info()["affinity"] = deviceNameWithoutBatch;
            bDetectionOutput = true;
        }
    }
    return bDetectionOutput ? NetworkBatchAbility::WITH_HETERO : NetworkBatchAbility::AS_IS;
}

}  // namespace details
}  // namespace InferenceEngine
|
||||
23
src/inference/src/check_network_batchable.hpp
Normal file
23
src/inference/src/check_network_batchable.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace details {
|
||||
/**
 * @brief Checks if the input network is batch-able (e.g. no dynamic inputs, inputs have the batch dimension, etc.)
 * @param network A network (CNNNetwork) to check for automatic-batching applicability
 * @param deviceNoBatch Name of the target device without the batching prefix; used as the affinity
 *        value for the nodes (e.g. DetectionOutput) that must be executed outside of batching
 * @param strictly_track_dims Whether the batch-dimension tracking is performed strictly
 * @return An enum value indicating whether the network can be safely batched (with HETERO or as is) or not
 */
enum NetworkBatchAbility : uint32_t { NO = 0, AS_IS, WITH_HETERO };
NetworkBatchAbility isNetworkBatchable(const CNNNetwork& network,
                                       const std::string& deviceNoBatch,
                                       bool strictly_track_dims);
|
||||
|
||||
} // namespace details
|
||||
} // namespace InferenceEngine
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "any_copy.hpp"
|
||||
#include "check_network_batchable.hpp"
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
#include "compilation_context.hpp"
|
||||
#include "cpp/ie_cnn_network.h"
|
||||
@@ -557,6 +558,8 @@ public:
|
||||
std::string& deviceName,
|
||||
std::map<std::string, std::string>& config) {
|
||||
std::string deviceNameWithBatchSize, deviceNameWithoutBatch;
|
||||
// fully strict dims tracking by default (Auto-Batching is enabled implicitly)
|
||||
bool strictly_check_dims = true;
|
||||
if (deviceName.find("BATCH") != std::string::npos) {
|
||||
// explicitly enabled Auto-Batching
|
||||
auto pos = deviceName.find_first_of(":");
|
||||
@@ -564,6 +567,9 @@ public:
|
||||
return; // BATCH device is already configured via the config
|
||||
deviceNameWithBatchSize = deviceName.substr(pos + 1);
|
||||
deviceNameWithoutBatch = DeviceIDParser::getBatchDevice(deviceNameWithBatchSize);
|
||||
// when user sets the BATCH device explicitly, we may check the dims less strictly
|
||||
// as the result is being checked by the user
|
||||
strictly_check_dims = false;
|
||||
} else {
|
||||
// check whether the Auto-Batching is disabled explicitly
|
||||
const auto& batch_mode = config.find(ov::hint::allow_auto_batching.name());
|
||||
@@ -594,38 +600,18 @@ public:
|
||||
if (bExclReqsEnabled || (!bTputInPlg && !bTputInLoadCfg))
|
||||
return;
|
||||
}
|
||||
auto function = network.getFunction();
|
||||
// have to execute the DetectionOutput separately (without batching)
|
||||
// as this layer mix-in the values from the different inputs (batch id)
|
||||
bool bDetectionOutput = false;
|
||||
const std::string detectionOutputOpName = ngraph::op::DetectionOutput::get_type_info_static().name;
|
||||
const std::string resultOpName = ngraph::op::Result::get_type_info_static().name;
|
||||
for (auto&& node : function->get_ops()) {
|
||||
auto isDetectionOutputParent = [&detectionOutputOpName](decltype(node)& nd) {
|
||||
for (size_t n = 0; n < nd->get_input_size(); n++) {
|
||||
// the code below doesn't need to separate the versions (opsets) of the DetectionOutput
|
||||
// so type_info name check is enough
|
||||
// (if in a future there will be a new ver that doesn't mix the batch, this will be new op)
|
||||
if (detectionOutputOpName == nd->get_input_node_ptr(n)->get_type_info().name)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if ((detectionOutputOpName == node->get_type_info().name) ||
|
||||
((resultOpName == node->get_type_info().name) && isDetectionOutputParent(node))) {
|
||||
node->get_rt_info()["affinity"] = deviceNameWithoutBatch;
|
||||
bDetectionOutput = true;
|
||||
} else {
|
||||
node->get_rt_info()["affinity"] = "BATCH";
|
||||
}
|
||||
}
|
||||
auto batchConfig = deviceNameWithBatchSize.empty() ? deviceNameWithoutBatch : deviceNameWithBatchSize;
|
||||
if (bDetectionOutput) {
|
||||
auto res = InferenceEngine::details::isNetworkBatchable(network, deviceNameWithoutBatch, strictly_check_dims);
|
||||
switch (res) {
|
||||
case InferenceEngine::details::NetworkBatchAbility::NO:
|
||||
return;
|
||||
case InferenceEngine::details::NetworkBatchAbility::AS_IS:
|
||||
deviceName = "BATCH:" + batchConfig;
|
||||
break;
|
||||
case InferenceEngine::details::NetworkBatchAbility::WITH_HETERO:
|
||||
deviceName = "HETERO:BATCH," + deviceNameWithoutBatch;
|
||||
config[CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG)] = batchConfig;
|
||||
} else {
|
||||
deviceName = "BATCH:" + batchConfig;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -841,7 +841,7 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
|
||||
// find the batch dim
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ov::pass::FindBatch>(true, check_dims);
|
||||
m.register_pass<ov::pass::FindBatch>(false, check_dims);
|
||||
m.run_passes(function);
|
||||
// do not reshape/re-batch originally batched networks and when there are no inputs with the N* layouts
|
||||
// input(s) should have the batch dim as the first dim (current limitation of the auto-batching impl)
|
||||
@@ -871,6 +871,8 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
|
||||
for (size_t output_id = 0; output_id < results.size(); output_id++) {
|
||||
const auto& output = results[output_id];
|
||||
const auto& shape = output->get_output_partial_shape(0);
|
||||
if (shape.is_dynamic())
|
||||
IE_THROW(NotImplemented) << "Auto-batching does not support dynamic networks!";
|
||||
// check the batch dim: either 0th (and the original batch size of 1) or none
|
||||
if (shape.size() && ov::DimensionTracker::get_label(shape[0])) {
|
||||
if (shape[0] != 1)
|
||||
|
||||
Reference in New Issue
Block a user