diff --git a/inference-engine/src/inference_engine/ie_core.cpp b/inference-engine/src/inference_engine/ie_core.cpp
index 89e88055637..cf6725c908e 100644
--- a/inference-engine/src/inference_engine/ie_core.cpp
+++ b/inference-engine/src/inference_engine/ie_core.cpp
@@ -12,6 +12,9 @@
 #include <map>
 #include <memory>
 #include <mutex>
+#include <ngraph/graph_util.hpp>
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pass/constant_folding.hpp>
 #include <string>
 
 #include "ie_plugin_cpp.hpp"
@@ -294,6 +297,23 @@ public:
         QueryNetworkResult res;
         auto parsed = parseDeviceNameIntoConfig(deviceName, config);
         GetCPPPluginByName(parsed._deviceName).QueryNetwork(network, parsed._config, res);
+        if (!network.getFunction())
+            return res;
+
+        // WA for constant folded operations (plugins should support all folded ops)
+        const auto& func = network.getFunction();
+        auto specialized_function = ngraph::clone_function(*func);
+
+        ngraph::pass::ConstantFolding().run_on_function(specialized_function);
+        std::unordered_set<std::string> operationNames;
+        for (const auto& op : specialized_function->get_ops())
+            operationNames.emplace(op->get_friendly_name());
+
+        for (const auto& op : func->get_ops()) {
+            if (operationNames.find(op->get_friendly_name()) != operationNames.end())
+                continue;
+            res.supportedLayersMap[op->get_friendly_name()] = deviceName;
+        }
         return res;
     }
 
diff --git a/inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp b/inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp
index cec77e90339..4bbeb9460e6 100644
--- a/inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp
+++ b/inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp
@@ -86,7 +86,7 @@ public:
 
 class IEClassNetworkTest : public ::testing::Test {
 public:
-    CNNNetwork actualNetwork, simpleNetwork, multinputNetwork;
+    CNNNetwork actualNetwork, simpleNetwork, multinputNetwork, ksoNetwork;
 
     void SetUp() override {
         // Generic network
@@ -104,6 +104,11 @@ public:
             auto fnPtr = ngraph::builder::subgraph::make2InputSubtract();
             multinputNetwork = InferenceEngine::CNNNetwork(fnPtr);
         }
+        // Network with KSO
+        {
+            auto fnPtr = ngraph::builder::subgraph::makeKSOFunction();
+            ksoNetwork = InferenceEngine::CNNNetwork(fnPtr);
+        }
     }
 
     void setHeteroNetworkAffinity(const std::string& targetDevice) {
         const std::map<std::string, std::string> deviceMapping = {
@@ -549,6 +554,49 @@ TEST_P(IEClassNetworkTestP, QueryNetworkActualNoThrow) {
     }
 }
 
+TEST_P(IEClassNetworkTestP, QueryNetworkWithKSO) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    Core ie;
+
+    try {
+        auto rres = ie.QueryNetwork(ksoNetwork, deviceName);
+        auto rl_map = rres.supportedLayersMap;
+        auto func = ksoNetwork.getFunction();
+        for (const auto & op : func->get_ops()) {
+            if (!rl_map.count(op->get_friendly_name())) {
+                FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
+            }
+        }
+    } catch (const InferenceEngine::details::InferenceEngineException & ex) {
+        std::string message = ex.what();
+        ASSERT_STR_CONTAINS(message, "[NOT_IMPLEMENTED] ngraph::Function is not supported natively");
+    }
+}
+
+TEST_P(IEClassNetworkTestP, SetAffinityWithKSO) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    Core ie;
+
+    try {
+        auto rres = ie.QueryNetwork(ksoNetwork, deviceName);
+        auto rl_map = rres.supportedLayersMap;
+        auto func = ksoNetwork.getFunction();
+        for (const auto & op : func->get_ops()) {
+            if (!rl_map.count(op->get_friendly_name())) {
+                FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
+            }
+        }
+        for (const auto & op : ksoNetwork.getFunction()->get_ops()) {
+            std::string affinity = rl_map[op->get_friendly_name()];
+            op->get_rt_info()["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>(affinity);
+        }
+        ExecutableNetwork exeNetwork = ie.LoadNetwork(ksoNetwork, deviceName);
+    } catch (const InferenceEngine::details::InferenceEngineException & ex) {
+        std::string message = ex.what();
+        ASSERT_STR_CONTAINS(message, "[NOT_IMPLEMENTED] ngraph::Function is not supported natively");
+    }
+}
+
 TEST_P(IEClassNetworkTestP, QueryNetworkHeteroActualNoThrow) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
     Core ie;
diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp
index d6f002f2218..8064ffb0164 100644
--- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp
+++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp
@@ -60,6 +60,26 @@ static std::shared_ptr<ngraph::Function> makeSplitConvConcat(std::vector<size_t>
     return fnPtr;
 }
 
+static std::shared_ptr<ngraph::Function> makeKSOFunction(std::vector<size_t> inputShape = {1, 4, 20, 20},
+                                                         InferenceEngine::Precision netPrecision = InferenceEngine::Precision::FP32) {
+    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
+    auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
+
+    auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(params[0]);
+    auto convert = std::make_shared<ngraph::opset3::Convert>(shapeOf, ngPrc);
+    auto newShape = ngraph::builder::makeConstant<int64_t>(ngraph::element::i64, {4}, {1, 4, 1, 1});
+    auto reshape = std::make_shared<ngraph::opset3::Reshape>(convert, newShape, false);
+    auto conv1 = ngraph::builder::makeConvolution(params[0], ngPrc, {3, 3}, {1, 1}, {0, 0}, {0, 0}, {1, 1},
+                                                  ngraph::op::PadType::EXPLICIT, 4);
+    auto relu1 = std::make_shared<ngraph::opset3::Relu>(conv1);
+    auto add = std::make_shared<ngraph::opset3::Add>(relu1, reshape);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(add)};
+    std::shared_ptr<ngraph::Function> fnPtr = std::make_shared<ngraph::Function>(results, params);
+    fnPtr->set_friendly_name("KSOFunction");
+    return fnPtr;
+}
+
 static std::shared_ptr<ngraph::Function> makeSplitMultiConvConcat(std::vector<size_t> inputShape = {1, 4, 20, 20}) {
     auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(InferenceEngine::Precision::FP32);
     auto params = ngraph::builder::makeParams(ngPrc, {inputShape});