Fixed query network for networks with KSO (#2201)
* Added a test to reproduce QueryNetwork with KSO * Fixed QueryNetwork for networks with KSO * Added additional test
This commit is contained in:
parent
baac903cdc
commit
1bae5504ca
@ -12,6 +12,9 @@
|
||||
#include <ie_core.hpp>
|
||||
#include <multi-device/multi_device_config.hpp>
|
||||
#include <ngraph/opsets/opset.hpp>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/graph_util.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
|
||||
#include <cpp_interfaces/exception2status.hpp>
|
||||
#include "ie_plugin_cpp.hpp"
|
||||
@ -294,6 +297,23 @@ public:
|
||||
QueryNetworkResult res;
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||
GetCPPPluginByName(parsed._deviceName).QueryNetwork(network, parsed._config, res);
|
||||
if (!network.getFunction())
|
||||
return res;
|
||||
|
||||
// WA for constant folded operations (plugins should support all folded ops)
|
||||
const auto& func = network.getFunction();
|
||||
auto specialized_function = ngraph::clone_function(*func);
|
||||
|
||||
ngraph::pass::ConstantFolding().run_on_function(specialized_function);
|
||||
std::unordered_set<std::string> operationNames;
|
||||
for (const auto& op : specialized_function->get_ops())
|
||||
operationNames.emplace(op->get_friendly_name());
|
||||
|
||||
for (const auto& op : func->get_ops()) {
|
||||
if (operationNames.find(op->get_friendly_name()) != operationNames.end())
|
||||
continue;
|
||||
res.supportedLayersMap[op->get_friendly_name()] = deviceName;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ public:
|
||||
|
||||
class IEClassNetworkTest : public ::testing::Test {
|
||||
public:
|
||||
CNNNetwork actualNetwork, simpleNetwork, multinputNetwork;
|
||||
CNNNetwork actualNetwork, simpleNetwork, multinputNetwork, ksoNetwork;
|
||||
|
||||
void SetUp() override {
|
||||
// Generic network
|
||||
@ -104,6 +104,11 @@ public:
|
||||
auto fnPtr = ngraph::builder::subgraph::make2InputSubtract();
|
||||
multinputNetwork = InferenceEngine::CNNNetwork(fnPtr);
|
||||
}
|
||||
// Network with KSO
|
||||
{
|
||||
auto fnPtr = ngraph::builder::subgraph::makeKSOFunction();
|
||||
ksoNetwork = InferenceEngine::CNNNetwork(fnPtr);
|
||||
}
|
||||
}
|
||||
void setHeteroNetworkAffinity(const std::string& targetDevice) {
|
||||
const std::map<std::string, std::string> deviceMapping = {
|
||||
@ -549,6 +554,49 @@ TEST_P(IEClassNetworkTestP, QueryNetworkActualNoThrow) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(IEClassNetworkTestP, QueryNetworkWithKSO) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
Core ie;
|
||||
|
||||
try {
|
||||
auto rres = ie.QueryNetwork(ksoNetwork, deviceName);
|
||||
auto rl_map = rres.supportedLayersMap;
|
||||
auto func = ksoNetwork.getFunction();
|
||||
for (const auto & op : func->get_ops()) {
|
||||
if (!rl_map.count(op->get_friendly_name())) {
|
||||
FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
|
||||
}
|
||||
}
|
||||
} catch (const InferenceEngine::details::InferenceEngineException & ex) {
|
||||
std::string message = ex.what();
|
||||
ASSERT_STR_CONTAINS(message, "[NOT_IMPLEMENTED] ngraph::Function is not supported natively");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(IEClassNetworkTestP, SetAffinityWithKSO) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
Core ie;
|
||||
|
||||
try {
|
||||
auto rres = ie.QueryNetwork(ksoNetwork, deviceName);
|
||||
auto rl_map = rres.supportedLayersMap;
|
||||
auto func = ksoNetwork.getFunction();
|
||||
for (const auto & op : func->get_ops()) {
|
||||
if (!rl_map.count(op->get_friendly_name())) {
|
||||
FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
|
||||
}
|
||||
}
|
||||
for (const auto & op : ksoNetwork.getFunction()->get_ops()) {
|
||||
std::string affinity = rl_map[op->get_friendly_name()];
|
||||
op->get_rt_info()["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>(affinity);
|
||||
}
|
||||
ExecutableNetwork exeNetwork = ie.LoadNetwork(ksoNetwork, deviceName);
|
||||
} catch (const InferenceEngine::details::InferenceEngineException & ex) {
|
||||
std::string message = ex.what();
|
||||
ASSERT_STR_CONTAINS(message, "[NOT_IMPLEMENTED] ngraph::Function is not supported natively");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(IEClassNetworkTestP, QueryNetworkHeteroActualNoThrow) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
Core ie;
|
||||
|
@ -60,6 +60,26 @@ static std::shared_ptr<ngraph::Function> makeSplitConvConcat(std::vector<size_t>
|
||||
return fnPtr;
|
||||
}
|
||||
|
||||
static std::shared_ptr<ngraph::Function> makeKSOFunction(std::vector<size_t> inputShape = {1, 4, 20, 20},
|
||||
InferenceEngine::Precision netPrecision = InferenceEngine::Precision::FP32) {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||
|
||||
auto shapeOf = std::make_shared<ngraph::opset4::ShapeOf>(params[0]);
|
||||
auto convert = std::make_shared<ngraph::opset4::Convert>(shapeOf, ngPrc);
|
||||
auto newShape = ngraph::builder::makeConstant<int64_t>(ngraph::element::i64, {4}, {1, 4, 1, 1});
|
||||
auto reshape = std::make_shared<ngraph::opset4::Reshape>(convert, newShape, false);
|
||||
auto conv1 = ngraph::builder::makeConvolution(params[0], ngPrc, {3, 3}, {1, 1}, {0, 0}, {0, 0}, {1, 1},
|
||||
ngraph::op::PadType::EXPLICIT, 4);
|
||||
auto relu1 = std::make_shared<ngraph::opset4::Relu>(conv1);
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(relu1, reshape);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(add)};
|
||||
std::shared_ptr<ngraph::Function> fnPtr = std::make_shared<ngraph::Function>(results, params);
|
||||
fnPtr->set_friendly_name("KSOFunction");
|
||||
return fnPtr;
|
||||
}
|
||||
|
||||
static std::shared_ptr<ngraph::Function> makeSplitMultiConvConcat(std::vector<size_t> inputShape = {1, 4, 20, 20}) {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(InferenceEngine::Precision::FP32);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||
|
Loading…
Reference in New Issue
Block a user