Removed QueryNetworkResult from new API (#7507)

This commit is contained in:
Ilya Lavrenov 2021-09-15 07:53:47 +03:00 committed by GitHub
parent bdaa44d0be
commit 7654789451
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 20 deletions

View File

@ -22,5 +22,13 @@ namespace runtime {
* @brief This type of map is commonly used to pass set of parameters
*/
using ConfigMap = std::map<std::string, std::string>;
/**
* @brief This type of map is used for result of Core::query_model
* - `key` means operation name
* - `value` means device name supporting this operation
*/
using SupportedOpsMap = std::map<std::string, std::string>;
} // namespace runtime
} // namespace ov

View File

@ -188,11 +188,11 @@ public:
* @param deviceName A name of a device to query
* @param network Network object to query
* @param config Optional map of pairs: (config parameter name, config parameter value)
* @return An object containing a map of pairs a layer name -> a device name supporting this layer.
* @return An object containing a map of pairs a operation name -> a device name supporting this operation.
*/
ie::QueryNetworkResult query_model(const std::shared_ptr<const ov::Function>& network,
const std::string& deviceName,
const ConfigMap& config = {}) const;
SupportedOpsMap query_model(const std::shared_ptr<const ov::Function>& network,
const std::string& deviceName,
const ConfigMap& config = {}) const;
/**
* @brief Sets configuration for device, acceptable keys can be found in ie_plugin_config.hpp

View File

@ -1306,11 +1306,14 @@ ExecutableNetwork Core::import_model(std::istream& networkModel,
return {exec._so, exec._ptr};
}
ie::QueryNetworkResult Core::query_model(const std::shared_ptr<const ngraph::Function>& network,
const std::string& deviceName,
const ConfigMap& config) const {
return _impl->QueryNetwork(ie::CNNNetwork(std::const_pointer_cast<ngraph::Function>(network)), deviceName, config);
SupportedOpsMap Core::query_model(const std::shared_ptr<const ngraph::Function>& network,
const std::string& deviceName,
const ConfigMap& config) const {
auto cnnNet = ie::CNNNetwork(std::const_pointer_cast<ngraph::Function>(network));
auto qnResult = _impl->QueryNetwork(cnnNet, deviceName, config);
return qnResult.supportedLayersMap;
}
void Core::set_config(const ConfigMap& config, const std::string& deviceName) {
// HETERO case
if (deviceName.find("HETERO:") == 0) {

View File

@ -512,8 +512,7 @@ TEST_P(OVClassNetworkTestP, QueryNetworkWithKSO) {
ov::runtime::Core ie = createCoreWithTemplate();
try {
auto rres = ie.query_model(ksoNetwork, deviceName);
auto rl_map = rres.supportedLayersMap;
auto rl_map = ie.query_model(ksoNetwork, deviceName);
auto func = ksoNetwork;
for (const auto& op : func->get_ops()) {
if (!rl_map.count(op->get_friendly_name())) {
@ -556,8 +555,7 @@ TEST_P(OVClassNetworkTestP, SetAffinityWithConstantBranches) {
func = std::make_shared<ngraph::Function>(results, params);
}
auto rres = ie.query_model(func, deviceName);
auto rl_map = rres.supportedLayersMap;
auto rl_map = ie.query_model(func, deviceName);
for (const auto& op : func->get_ops()) {
if (!rl_map.count(op->get_friendly_name())) {
FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
@ -579,8 +577,7 @@ TEST_P(OVClassNetworkTestP, SetAffinityWithKSO) {
ov::runtime::Core ie = createCoreWithTemplate();
try {
auto rres = ie.query_model(ksoNetwork, deviceName);
auto rl_map = rres.supportedLayersMap;
auto rl_map = ie.query_model(ksoNetwork, deviceName);
auto func = ksoNetwork;
for (const auto& op : func->get_ops()) {
if (!rl_map.count(op->get_friendly_name())) {
@ -601,10 +598,10 @@ TEST_P(OVClassNetworkTestP, SetAffinityWithKSO) {
TEST_P(OVClassNetworkTestP, QueryNetworkHeteroActualNoThrow) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
ov::runtime::Core ie = createCoreWithTemplate();
QueryNetworkResult res;
ov::runtime::SupportedOpsMap res;
ASSERT_NO_THROW(
res = ie.query_model(actualNetwork, CommonTestUtils::DEVICE_HETERO, {{"TARGET_FALLBACK", deviceName}}));
ASSERT_LT(0, res.supportedLayersMap.size());
ASSERT_LT(0, res.size());
}
TEST_P(OVClassNetworkTestP, QueryNetworkMultiThrows) {
@ -1408,7 +1405,7 @@ TEST_P(OVClassLoadNetworkTest, QueryNetworkHETEROWithMULTINoThrow_V10) {
for (auto&& node : function->get_ops()) {
expectedLayers.emplace(node->get_friendly_name());
}
QueryNetworkResult result;
ov::runtime::SupportedOpsMap result;
std::string targetFallback(CommonTestUtils::DEVICE_MULTI + std::string(",") + deviceName);
ASSERT_NO_THROW(result = ie.query_model(
multinputNetwork,
@ -1416,7 +1413,7 @@ TEST_P(OVClassLoadNetworkTest, QueryNetworkHETEROWithMULTINoThrow_V10) {
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES), devices}, {"TARGET_FALLBACK", targetFallback}}));
std::unordered_set<std::string> actualLayers;
for (auto&& layer : result.supportedLayersMap) {
for (auto&& layer : result) {
actualLayers.emplace(layer.first);
}
ASSERT_EQ(expectedLayers, actualLayers);
@ -1444,14 +1441,14 @@ TEST_P(OVClassLoadNetworkTest, QueryNetworkMULTIWithHETERONoThrow_V10) {
for (auto&& node : function->get_ops()) {
expectedLayers.emplace(node->get_friendly_name());
}
QueryNetworkResult result;
ov::runtime::SupportedOpsMap result;
ASSERT_NO_THROW(result = ie.query_model(multinputNetwork,
CommonTestUtils::DEVICE_MULTI,
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES), devices},
{"TARGET_FALLBACK", deviceName + "," + deviceName}}));
std::unordered_set<std::string> actualLayers;
for (auto&& layer : result.supportedLayersMap) {
for (auto&& layer : result) {
actualLayers.emplace(layer.first);
}
ASSERT_EQ(expectedLayers, actualLayers);