[GPU] QueryNetwork method correction to work with dynamic shapes (#9462)
This commit is contained in:
parent
2a476f6906
commit
9e41208791
@ -699,10 +699,25 @@ public:
|
||||
opNames.emplace(op->get_friendly_name());
|
||||
|
||||
for (const auto& op : func->get_ops()) {
|
||||
if (opNames.find(op->get_friendly_name()) == opNames.end() ||
|
||||
(!res.supportedLayersMap.count(op->get_friendly_name()) &&
|
||||
std::dynamic_pointer_cast<ngraph::op::Constant>(op)))
|
||||
if (opNames.find(op->get_friendly_name()) == opNames.end()) {
|
||||
res.supportedLayersMap[op->get_friendly_name()] = defDevice;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& op : func->get_ops()) {
|
||||
if (!res.supportedLayersMap.count(op->get_friendly_name()) &&
|
||||
std::dynamic_pointer_cast<ngraph::op::Constant>(op)) {
|
||||
bool are_all_users_supported = true;
|
||||
for (const auto& user : op->output(0).get_target_inputs()) {
|
||||
if (!res.supportedLayersMap.count(user.get_node()->get_friendly_name())) {
|
||||
are_all_users_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (are_all_users_supported) {
|
||||
res.supportedLayersMap[op->get_friendly_name()] = defDevice;
|
||||
}
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -380,44 +380,18 @@ QueryNetworkResult Plugin::QueryNetwork(const CNNNetwork& network,
|
||||
std::unordered_set<std::string> supported;
|
||||
std::unordered_set<std::string> unsupported;
|
||||
|
||||
std::unordered_set<std::string> splitNames;
|
||||
std::unordered_set<std::string> concatNames;
|
||||
std::unordered_set<std::string> constantsNames;
|
||||
std::unordered_set<std::string> depLayerNames;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::Node>> splits;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> concats;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> constants;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> nextLayerDependent;
|
||||
|
||||
auto layerIsSupported = [&](std::shared_ptr<ngraph::Node> node) {
|
||||
if (node->is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
if (ngraph::is_type<const ngraph::op::v0::DetectionOutput>(node) ||
|
||||
ngraph::is_type<const ngraph::op::v0::PriorBox>(node) ||
|
||||
if (ngraph::is_type<const ngraph::op::v0::PriorBox>(node) ||
|
||||
ngraph::is_type<const ngraph::op::v0::PriorBoxClustered>(node) ||
|
||||
ngraph::is_type<const ngraph::op::v0::Proposal>(node)) {
|
||||
return false;
|
||||
}
|
||||
if (ngraph::is_type<const ngraph::op::v1::Split>(node)) {
|
||||
splitNames.emplace(node->get_friendly_name());
|
||||
splits.push_back(node);
|
||||
return false;
|
||||
}
|
||||
if (ngraph::is_type<const ngraph::op::v0::Concat>(node)) {
|
||||
concatNames.emplace(node->get_friendly_name());
|
||||
concats.push_back(node);
|
||||
return false;
|
||||
}
|
||||
if (ngraph::is_type<const ngraph::op::v1::Reshape>(node) ||
|
||||
ngraph::is_type<const ngraph::op::v0::Squeeze>(node) ||
|
||||
ngraph::is_type<const ngraph::op::v0::Unsqueeze>(node) ||
|
||||
ngraph::is_type<const ngraph::op::v1::Transpose>(node)) {
|
||||
depLayerNames.emplace(node->get_friendly_name());
|
||||
nextLayerDependent.push_back(node);
|
||||
return false;
|
||||
}
|
||||
if (ngraph::is_type<const ngraph::op::v0::Constant>(node)) {
|
||||
constantsNames.emplace(node->get_friendly_name());
|
||||
constants.push_back(node);
|
||||
@ -431,10 +405,18 @@ QueryNetworkResult Plugin::QueryNetwork(const CNNNetwork& network,
|
||||
// Get ops after transformations and check if it's supported
|
||||
// Transformations might lead to a situation where a single node is merged into multiple operations,
|
||||
// so we mark original op as supported only if all nodes that it was merged into are supported
|
||||
bool wasNodeAlreadyChecked = false;
|
||||
bool isSupported = false;
|
||||
for (auto&& op : ops) {
|
||||
wasNodeAlreadyChecked = false;
|
||||
isSupported = false;
|
||||
for (auto&& fusedLayerName : ngraph::getFusedNamesVector(op)) {
|
||||
if (InferenceEngine::details::contains(originalOpNames, fusedLayerName)) {
|
||||
if (layerIsSupported(op)) {
|
||||
if (!wasNodeAlreadyChecked) {
|
||||
isSupported = layerIsSupported(op);
|
||||
wasNodeAlreadyChecked = true;
|
||||
}
|
||||
if (isSupported) {
|
||||
supported.emplace(fusedLayerName);
|
||||
} else {
|
||||
unsupported.emplace(fusedLayerName);
|
||||
@ -450,77 +432,7 @@ QueryNetworkResult Plugin::QueryNetwork(const CNNNetwork& network,
|
||||
}
|
||||
unsupported.clear();
|
||||
|
||||
// Check a set of heuristics to produce a more efficient hetero sub-graph. Note: the order of the checks is important.
|
||||
// 1. Split is marked as supported when all output ops can be offloaded to GPU
|
||||
for (const auto & op : splits) {
|
||||
bool is_supported = true;
|
||||
for (size_t i = 0; i < op->get_output_size(); i++) {
|
||||
auto outTensors = op->get_output_target_inputs(i);
|
||||
for (auto& t : outTensors) {
|
||||
auto output = t.get_node();
|
||||
const auto& name = output->get_friendly_name();
|
||||
if (!InferenceEngine::details::contains(supported, name) &&
|
||||
!InferenceEngine::details::contains(depLayerNames, name) &&
|
||||
!InferenceEngine::details::contains(concatNames, name) &&
|
||||
!InferenceEngine::details::contains(splitNames, name)) {
|
||||
is_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_supported) {
|
||||
supported.emplace(op->get_friendly_name());
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Concat is marked as supported when all inputs can be offloaded to GPU
|
||||
for (const auto& op : concats) {
|
||||
bool is_supported = true;
|
||||
for (size_t i = 0; i < op->get_input_size(); i++) {
|
||||
auto input = op->get_input_node_shared_ptr(i);
|
||||
const auto& name = input->get_friendly_name();
|
||||
if (!InferenceEngine::details::contains(supported, name) &&
|
||||
!InferenceEngine::details::contains(depLayerNames, name) &&
|
||||
!InferenceEngine::details::contains(concatNames, name)) {
|
||||
is_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_supported) {
|
||||
supported.emplace(op->get_friendly_name());
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Some layers are marked as supported when all inputs and outputs can be offloaded to GPU
|
||||
for (const auto& op : nextLayerDependent) {
|
||||
bool is_supported = true;
|
||||
// both inputs and output should be GPU to remain on GPU
|
||||
for (size_t i = 0; i < op->get_input_size(); i++) {
|
||||
auto input = op->get_input_node_shared_ptr(i);
|
||||
const auto& name = input->get_friendly_name();
|
||||
// All inputs must be supported or be a constant
|
||||
if (!InferenceEngine::details::contains(supported, name) && !InferenceEngine::details::contains(constantsNames, name)) {
|
||||
is_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < op->get_output_size(); i++) {
|
||||
auto outTensors = op->get_output_target_inputs(i);
|
||||
for (auto& t : outTensors) {
|
||||
auto output = t.get_node();
|
||||
const auto& name = output->get_friendly_name();
|
||||
if (!InferenceEngine::details::contains(supported, name)) {
|
||||
is_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_supported) {
|
||||
supported.emplace(op->get_friendly_name());
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Constants are marked as supported when all outputs can be offloaded to GPU
|
||||
// 1. Constants are marked as supported when all outputs can be offloaded to GPU
|
||||
for (const auto& op : constants) {
|
||||
bool is_supported = true;
|
||||
for (size_t i = 0; i < op->get_output_size(); i++) {
|
||||
@ -558,14 +470,14 @@ QueryNetworkResult Plugin::QueryNetwork(const CNNNetwork& network,
|
||||
}
|
||||
|
||||
if (ngraph::op::is_constant(node) || ngraph::op::is_parameter(node)) {
|
||||
if (!InferenceEngine::details::contains(supported, node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) {
|
||||
supported.erase(node->get_friendly_name());
|
||||
}
|
||||
} else if (ngraph::op::is_output(node)) {
|
||||
if (!InferenceEngine::details::contains(supported, node->input_values().begin()->get_node()->get_friendly_name())) {
|
||||
supported.erase(node->get_friendly_name());
|
||||
}
|
||||
if (!InferenceEngine::details::contains(supported, node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) {
|
||||
supported.erase(node->get_friendly_name());
|
||||
}
|
||||
} else if (ngraph::op::is_output(node)) {
|
||||
if (!InferenceEngine::details::contains(supported, node->input_values().begin()->get_node()->get_friendly_name())) {
|
||||
supported.erase(node->get_friendly_name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto&& layerName : supported) {
|
||||
|
Loading…
Reference in New Issue
Block a user