diff --git a/inference-engine/src/mkldnn_plugin/cpu_types.cpp b/inference-engine/src/mkldnn_plugin/cpu_types.cpp index 83fc153bd06..4af6683bf78 100644 --- a/inference-engine/src/mkldnn_plugin/cpu_types.cpp +++ b/inference-engine/src/mkldnn_plugin/cpu_types.cpp @@ -144,6 +144,7 @@ const InferenceEngine::details::caseless_unordered_map type_t { "Cosh", Math}, { "Floor", Math}, { "HardSigmoid", Math}, + { "If", If}, { "Log", Math}, { "Neg", Math}, { "Reciprocal", Math}, @@ -320,6 +321,8 @@ std::string NameFromType(const Type type) { return "DetectionOutput"; case ExperimentalDetectronDetectionOutput: return "ExperimentalDetectronDetectionOutput"; + case If: + return "If"; case LogSoftmax: return "LogSoftmax"; case TopK: diff --git a/inference-engine/src/mkldnn_plugin/cpu_types.h b/inference-engine/src/mkldnn_plugin/cpu_types.h index 33f675fbb54..95371b6c847 100644 --- a/inference-engine/src/mkldnn_plugin/cpu_types.h +++ b/inference-engine/src/mkldnn_plugin/cpu_types.h @@ -17,6 +17,7 @@ using VectorDims = std::vector; enum Type { Unknown, Generic, + If, Reorder, Input, Output, diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp index 9efa0752644..89eafa7e987 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -1127,10 +1128,16 @@ MKLDNNNode* MKLDNNNode::NodesFactory::create(const std::shared_ptr // WA-start : TI node requires all attributes to construct internal subgpath // including extManager, socket and mkldnn::eng. - MKLDNNTensorIteratorNode *ti = dynamic_cast(newNode); - if (ti != nullptr) - ti->setExtManager(extMgr); - // WA-end + if (newNode) { + if (newNode->getType() == TensorIterator) { + if (auto ti = dynamic_cast(newNode)) + ti->setExtManager(extMgr); + } else if (newNode->getType() == If) { + if (auto ifNode = dynamic_cast(newNode)) + ifNode->setExtManager(extMgr); + } + } +// // WA-end if (!newNode) { std::string errorDetails; diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.cpp new file mode 100644 index 00000000000..84ebf2101e6 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.cpp @@ -0,0 +1,230 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "mkldnn_if_node.h" + +#include +#include + +#include +#include +#include + +using namespace MKLDNNPlugin; + +MKLDNNIfNode::PortMapHelper::PortMapHelper(const MKLDNNMemoryPtr &from, const MKLDNNMemoryPtr &to, const mkldnn::engine& eng) { + mem_holder_src = from->GetPrimitive(); + mem_holder_dst = to->GetPrimitive(); + reorder = {mem_holder_src, mem_holder_dst}; +} + +void MKLDNNIfNode::PortMapHelper::execute(mkldnn::stream& strm) { + reorder.execute(strm, mem_holder_src, mem_holder_dst); +} + +bool MKLDNNIfNode::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { + try { + if (isDynamicNgraphNode(op)) { + errorMessage = "If node doesn't support op with dynamic shapes"; + return false; + } + if (!one_of(op->get_type_info(), + ov::op::v8::If::type_info)) { + errorMessage = "Not supported If operation version " + std::to_string(op->get_type_info().version) + + " with name '" + op->get_friendly_name() + "'. Node If supports only opset8 version."; + return false; + } + } catch (...) { + return false; + } + return true; +} + +MKLDNNIfNode::MKLDNNIfNode(const std::shared_ptr& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) : + MKLDNNNode(op, eng, cache), ovOp(op) { + std::string errorMessage; + if (!isSupportedOperation(op, errorMessage)) { + IE_THROW(NotImplemented) << errorMessage; + } +} + +void MKLDNNIfNode::getSupportedDescriptors() { + auto ifOp = ov::as_type_ptr(ovOp); + + const std::shared_ptr& thenBody = ifOp->get_then_body(); + const std::shared_ptr& elseBody = ifOp->get_else_body(); + subGraphThen.CreateGraph(thenBody, ext_mng, weightCache); + subGraphElse.CreateGraph(elseBody, ext_mng, weightCache); + + const auto &inMapThen = subGraphThen.GetInputNodesMap(); + for (const auto ¶m : ifOp->get_then_body()->get_parameters()) { + auto inNode = inMapThen.find(param->get_friendly_name()); + if (inNode != inMapThen.end()) { + auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr(); + inputMemThen.push_back(inMem); + } else { + IE_THROW() << "Then body of node If with name " << getName() << " does not have input with name: " + << param->get_friendly_name(); + } + } + + const auto &inMapElse = subGraphElse.GetInputNodesMap(); + for (const auto ¶m : ifOp->get_else_body()->get_parameters()) { + auto inNode = inMapElse.find(param->get_friendly_name()); + if (inNode != inMapElse.end()) { + auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr(); + inputMemElse.push_back(inMem); + } else { + IE_THROW() << "Else body of node If with name " << getName() << " does not have input with name: " + << param->get_friendly_name(); + } + } + + const auto &outMapThen = subGraphThen.GetOutputNodesMap(); + for (const auto& out : ifOp->get_then_body()->get_results()) { + auto prev = out->get_input_node_shared_ptr(0); + std::string inputID = prev->get_friendly_name(); + if (prev->get_output_size() > 1) { + inputID += "." + std::to_string(out->get_input_source_output(0).get_index()); + } + auto outNode = outMapThen.find(inputID); + if (outNode != outMapThen.end()) { + auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr(); + outputMemThen.push_back(outMem); + } else { + IE_THROW() << "Then body of node If with name " << getName() << " does not have output with name: " + << inputID; + } + } + + const auto &outMapElse = subGraphElse.GetOutputNodesMap(); + for (const auto& out : ifOp->get_else_body()->get_results()) { + auto prev = out->get_input_node_shared_ptr(0); + std::string inputID = prev->get_friendly_name(); + if (prev->get_output_size() > 1) { + inputID += "." + std::to_string(out->get_input_source_output(0).get_index()); + } + auto outNode = outMapElse.find(inputID); + if (outNode != outMapElse.end()) { + auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr(); + outputMemElse.push_back(outMem); + } else { + IE_THROW() << "Else body of node If with name " << getName() << " does not have output with name: " + << inputID; + } + } + + // Port map: outputs + for (const auto& desc : ifOp->get_output_descriptions(0)) { + auto body_output_idx = desc->m_body_value_index; + thenOutputPortMap.emplace_back(PortMap { + static_cast(desc->m_output_index), static_cast(body_output_idx)}); + } + for (const auto& desc : ifOp->get_output_descriptions(1)) { + auto body_output_idx = desc->m_body_value_index; + elseOutputPortMap.emplace_back(PortMap { + static_cast(desc->m_output_index), static_cast(body_output_idx)}); + } + + for (const auto& desc : ifOp->get_input_descriptions(0)) { + auto body_input_index = desc->m_body_parameter_index; + thenInputPortMap.emplace_back(PortMap { + static_cast(desc->m_input_index), static_cast(body_input_index)}); + } + for (const auto& desc : ifOp->get_input_descriptions(1)) { + auto body_input_index = desc->m_body_parameter_index; + elseInputPortMap.emplace_back(PortMap { + static_cast(desc->m_input_index), static_cast(body_input_index)}); + } +} + +void MKLDNNIfNode::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + + NodeConfig config; + config.inConfs.reserve(getParentEdges().size()); + config.outConfs.reserve(getChildEdges().size()); + + for (size_t i = 0; i < inputShapes.size(); i++) { + auto dims = inputShapes[i].getDims(); + + PortConfig dataConf {}; + auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp); + dataConf.desc = descCreator->createSharedDesc(getOriginalInputPrecisionAtPort(i), Shape(dims)); + config.inConfs.emplace_back(dataConf); + } + + for (size_t i = 0; i < outputShapes.size(); i++) { + auto dims = outputShapes[i].getDims(); + + PortConfig dataConf {}; + auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp); + dataConf.desc = descCreator->createSharedDesc(getOriginalOutputPrecisionAtPort(i), Shape(dims)); + config.outConfs.push_back(dataConf); + } + + config.dynBatchSupport = true; + + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown); +} + + +void MKLDNNIfNode::createPrimitive() { + const auto& eng = getEngine(); + + for (auto& map_rule : thenInputPortMap) { + auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr(); + auto &toMem = inputMemThen[map_rule.to]; + + beforeThenMappers.emplace_back(std::make_shared(fromMem, toMem, eng)); + } + + for (auto& map_rule : elseInputPortMap) { + auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr(); + auto &toMem = inputMemElse[map_rule.to]; + + beforeElseMappers.emplace_back(std::make_shared(fromMem, toMem, eng)); + } + + for (auto& map_rule : thenOutputPortMap) { + auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr(); + auto &fromMem = outputMemThen[map_rule.to]; + + afterThenMappers.emplace_back(std::make_shared(fromMem, toMem, eng)); + } + + for (auto& map_rule : elseOutputPortMap) { + auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr(); + auto &fromMem = outputMemElse[map_rule.to]; + + afterElseMappers.emplace_back(std::make_shared(fromMem, toMem, eng)); + } +} + +void MKLDNNIfNode::execute(mkldnn::stream strm) { + const bool condition = *(reinterpret_cast(getParentEdgeAt(0)->getMemoryPtr()->GetPtr())); + + if (condition) { + for (auto &mapper : beforeThenMappers) + mapper->execute(strm); + subGraphThen.ResetInferCount(); + subGraphThen.Infer(); + for (auto &mapper : afterThenMappers) + mapper->execute(strm); + } else { + for (auto &mapper : beforeElseMappers) + mapper->execute(strm); + subGraphElse.ResetInferCount(); + subGraphElse.Infer(); + for (auto &mapper : afterElseMappers) + mapper->execute(strm); + } +} + +bool MKLDNNIfNode::created() const { + return getType() == If; +} + +REG_MKLDNN_PRIM_FOR(MKLDNNIfNode, If); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.h new file mode 100644 index 00000000000..c384945b442 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.h @@ -0,0 +1,66 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include +#include +#include + +namespace MKLDNNPlugin { + +class MKLDNNIfNode : public MKLDNNNode { +public: + MKLDNNIfNode(const std::shared_ptr& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache); + + static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + void initSupportedPrimitiveDescriptors() override; + void getSupportedDescriptors() override; + void createPrimitive() override; + bool created() const override; + void execute(mkldnn::stream strm) override; + + void inline setExtManager(const MKLDNNExtensionManager::Ptr& extMgr) { ext_mng = extMgr; } + + struct PortMap { + int from; /**< Index of external/internal out data */ + int to; /**< Index of external/internal in data */ + }; + + class PortMapHelper { + public: + PortMapHelper(const MKLDNNMemoryPtr& from, const MKLDNNMemoryPtr& to, const mkldnn::engine& eng); + virtual ~PortMapHelper() = default; + virtual void execute(mkldnn::stream& strm); + protected: + mkldnn::reorder reorder; + mkldnn::memory mem_holder_src; + mkldnn::memory mem_holder_dst; + }; + +private: + MKLDNNExtensionManager::Ptr ext_mng; + MKLDNNGraph subGraphThen; + MKLDNNGraph subGraphElse; + std::deque inputMemThen, inputMemElse, outputMemThen, outputMemElse; + + std::vector> + beforeThenMappers, + beforeElseMappers, + afterThenMappers, + afterElseMappers; + + std::vector + thenInputPortMap, + thenOutputPortMap, + elseInputPortMap, + elseOutputPortMap; + + const std::shared_ptr ovOp; +}; + +} // namespace MKLDNNPlugin diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/subgraph_tests/simple_if.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/subgraph_tests/simple_if.cpp new file mode 100644 index 00000000000..064d8ca9974 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/subgraph_tests/simple_if.cpp @@ -0,0 +1,57 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "shared_test_classes/subgraph/simple_if.hpp" + +using namespace SubgraphTestsDefinitions; + +namespace { + std::vector>> inputShapes = { + {{5, 7}, {5, 7}}, + {{30, 20, 10}, {30, 20, 10}} + }; + + std::vector netPrecisions = {InferenceEngine::Precision::FP32, + InferenceEngine::Precision::I8, + }; + + std::vector conditions = {true, false}; + + INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfTest, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(conditions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + SimpleIfTest::getTestCaseName); + + TEST_P(SimpleIfTest, CompareWithRefs) { + Run(); + }; + + INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIf2OutTest, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(conditions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + SimpleIf2OutTest::getTestCaseName); + + TEST_P(SimpleIf2OutTest, CompareWithRefs) { + Run(); + }; + + INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfNotConstConditionTest, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(conditions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + SimpleIfNotConstConditionTest::getTestCaseName); + + TEST_P(SimpleIfNotConstConditionTest, CompareWithRefs) { + Run(); + }; + +} // namespace diff --git a/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/simple_if.hpp b/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/simple_if.hpp new file mode 100644 index 00000000000..db129e71e66 --- /dev/null +++ b/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/simple_if.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "shared_test_classes/base/layer_test_utils.hpp" + +namespace SubgraphTestsDefinitions { + +using SimpleIfParamsTuple = typename std::tuple< + std::vector>, // Input shapes + InferenceEngine::Precision, // Network precision + bool, // If condition + std::string // Device name +>; + +class SimpleIfTest: + public testing::WithParamInterface, + public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj); +protected: + void SetUp() override; +}; + +class SimpleIf2OutTest : public SimpleIfTest { +protected: + void SetUp() override; +}; + +class SimpleIfNotConstConditionTest : public SimpleIfTest { +public: + InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override; + +protected: + void SetUp() override; + + bool condition; +}; + +} // namespace SubgraphTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/src/subgraph/simple_if.cpp b/inference-engine/tests/functional/shared_test_classes/src/subgraph/simple_if.cpp new file mode 100644 index 00000000000..4ac1d7e43ba --- /dev/null +++ b/inference-engine/tests/functional/shared_test_classes/src/subgraph/simple_if.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "shared_test_classes/subgraph/simple_if.hpp" +#include "ngraph_functions/builders.hpp" + +namespace SubgraphTestsDefinitions { +std::string SimpleIfTest::getTestCaseName(const testing::TestParamInfo &obj) { + std::vector> inputShapes; + InferenceEngine::Precision netPrecision; + bool condition; + std::string targetName; + std::tie(inputShapes, netPrecision, condition, targetName) = obj.param; + std::ostringstream results; + + results << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_"; + results << "netPRC=" << netPrecision.name() << "_"; + results << "Cond=" << condition << "_"; + results << "targetDevice=" << targetName << "_"; + return results.str(); +} + +void SimpleIfTest::SetUp() { + std::vector> inputShapes; + auto netPrecision = InferenceEngine::Precision::UNSPECIFIED; + bool condition; + std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam(); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]}); + auto paramOuts = ngraph::helpers::convert2OutputVector( + ngraph::helpers::castOps2Nodes(params)); + + auto p1 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[0])); + auto p2 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[1])); + auto p3 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[0])); + + auto thenOp = std::make_shared(p1, p2); + auto res1 = std::make_shared(thenOp); + auto res2 = std::make_shared(p3); + + auto thenBody = std::make_shared(ov::OutputVector{res1}, ov::ParameterVector{p1, p2}); + auto elseBody = std::make_shared(ov::OutputVector{res2}, ov::ParameterVector{p3}); + + auto condOp = ngraph::builder::makeConstant(ov::element::Type_t::boolean, {1}, {condition}); + auto ifOp = std::make_shared(condOp); + ifOp->set_then_body(thenBody); + ifOp->set_else_body(elseBody); + ifOp->set_input(paramOuts[0], p1, p3); + ifOp->set_input(paramOuts[1], p2, nullptr); + auto res = ifOp->set_output(res1, res2); + + ov::ResultVector results{std::make_shared(res)}; + function = std::make_shared(results, params, "simpleIf"); +} + +void SimpleIf2OutTest::SetUp() { + std::vector> inputShapes; + auto netPrecision = InferenceEngine::Precision::UNSPECIFIED; + bool condition; + std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam(); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]}); + auto paramOuts = ngraph::helpers::convert2OutputVector( + ngraph::helpers::castOps2Nodes(params)); + + auto p1 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[0])); + auto p2 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[1])); + auto p3 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[0])); + auto p4 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[1])); + + auto thenOp = std::make_shared(p1, p2); + auto res1 = std::make_shared(thenOp); + auto res2 = std::make_shared(thenOp); + auto res3 = std::make_shared(p3); + auto res4 = std::make_shared(p4); + + auto thenBody = std::make_shared(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2}); + auto elseBody = std::make_shared(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4}); + + auto condOp = ngraph::builder::makeConstant(ov::element::Type_t::boolean, {1}, {condition}); + auto ifOp = std::make_shared(condOp); + ifOp->set_then_body(thenBody); + ifOp->set_else_body(elseBody); + ifOp->set_input(paramOuts[0], p1, p3); + ifOp->set_input(paramOuts[1], p2, p4); + auto ifRes1 = ifOp->set_output(res1, res3); + auto ifRes2 = ifOp->set_output(res2, res4); + + ov::ResultVector results{std::make_shared(ifRes1), std::make_shared(ifRes2)}; + function = std::make_shared(results, params, "simpleIf2Out"); +} + +void SimpleIfNotConstConditionTest::SetUp() { + std::vector> inputShapes; + auto netPrecision = InferenceEngine::Precision::UNSPECIFIED; + std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam(); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + ov::ParameterVector params { + ngraph::builder::makeParams(ngPrc, {inputShapes[0]})[0], + ngraph::builder::makeParams(ngPrc, {inputShapes[1]})[0], + ngraph::builder::makeParams(ov::element::boolean, { {"condition", {1}} })[0] + }; + auto paramOuts = ngraph::helpers::convert2OutputVector( + ngraph::helpers::castOps2Nodes(params)); + + auto p1 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[0])); + auto p2 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[1])); + auto p3 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[0])); + auto p4 = std::make_shared(ngPrc, ov::PartialShape(inputShapes[1])); + + auto thenOp = std::make_shared(p1, p2); + auto res1 = std::make_shared(thenOp); + auto res2 = std::make_shared(thenOp); + auto res3 = std::make_shared(p3); + auto res4 = std::make_shared(p4); + + auto thenBody = std::make_shared(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2}); + auto elseBody = std::make_shared(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4}); + + auto ifOp = std::make_shared(paramOuts[2]); + ifOp->set_then_body(thenBody); + ifOp->set_else_body(elseBody); + ifOp->set_input(paramOuts[0], p1, p3); + ifOp->set_input(paramOuts[1], p2, p4); + auto ifRes1 = ifOp->set_output(res1, res3); + auto ifRes2 = ifOp->set_output(res2, res4); + + ov::ResultVector results{std::make_shared(ifRes1), std::make_shared(ifRes2)}; + function = std::make_shared(results, params, "SimpleIfNotConstConditionTest"); +} + +InferenceEngine::Blob::Ptr SimpleIfNotConstConditionTest::GenerateInput(const InferenceEngine::InputInfo& info) const { + if (info.name() == "condition") { + bool conditionArr[1] = { condition }; + return FuncTestUtils::createAndFillBlobWithFloatArray(info.getTensorDesc(), conditionArr, 1); + } else { + return FuncTestUtils::createAndFillBlob(info.getTensorDesc()); + } +} + +} // namespace SubgraphTestsDefinitions diff --git a/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/utils/data_utils.hpp b/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/utils/data_utils.hpp index e01adb0c3c6..22cf4dd53b0 100644 --- a/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/utils/data_utils.hpp +++ b/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/utils/data_utils.hpp @@ -83,7 +83,7 @@ template std::vector castVector(const std::vector &vec) { std::vector resVec; resVec.reserve(vec.size()); - for (auto &el : vec) { + for (const auto &el : vec) { resVec.push_back(static_cast(el)); } return resVec; diff --git a/ngraph/core/include/openvino/op/util/multi_subgraph_base.hpp b/ngraph/core/include/openvino/op/util/multi_subgraph_base.hpp index fc81a70faa4..229f6a37036 100644 --- a/ngraph/core/include/openvino/op/util/multi_subgraph_base.hpp +++ b/ngraph/core/include/openvino/op/util/multi_subgraph_base.hpp @@ -273,6 +273,13 @@ public: /// \param bodies_results vector of bodies results for one output. /// \return value Output node for bodies_results. virtual Output set_body_outputs(const ResultVector& bodies_results); + /// + /// \brief Get number of internal sub-graphs + /// + /// \return Number of sub-graphs. + virtual size_t get_num_internal_subgraphs() const { + return m_bodies.size(); + } MultiSubGraphOp(const MultiSubGraphOp&) = delete; MultiSubGraphOp(MultiSubGraphOp&&) = default; diff --git a/ngraph/core/src/pass/constant_folding.cpp b/ngraph/core/src/pass/constant_folding.cpp index 8101f25dc12..d370b7a9319 100644 --- a/ngraph/core/src/pass/constant_folding.cpp +++ b/ngraph/core/src/pass/constant_folding.cpp @@ -45,9 +45,10 @@ bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr f) } } else { // recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop) - if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { - if (const auto& sub_graph = sub_graph_node->get_function()) { - rewritten |= run_on_function(sub_graph); + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { + size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs(); + for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) { + rewritten |= run_on_function(sub_graph_node->get_function(sub_graph_ind)); } } } diff --git a/ngraph/core/src/pass/convert_precision.cpp b/ngraph/core/src/pass/convert_precision.cpp index 1dbf7f9747f..a569fff516b 100644 --- a/ngraph/core/src/pass/convert_precision.cpp +++ b/ngraph/core/src/pass/convert_precision.cpp @@ -161,9 +161,10 @@ bool convert_precision(pass::PassBase& pass, for (auto& node : ops) { pass.transformation_callback(node); // Recursively apply transformation for sub-graph based operations - if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { - if (auto sub_graph = sub_graph_node->get_function()) { - is_changed |= convert_function_precision(sub_graph, true); + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { + size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs(); + for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) { + is_changed |= convert_function_precision(sub_graph_node->get_function(sub_graph_ind), true); } } is_changed |= convert_node_input_precision(node); @@ -226,9 +227,10 @@ precisions_set_t find_all_used_precisions(const std::shared_ptroutputs()) { used_precisions.emplace(output.get_element_type()); } - if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { - if (auto sub_graph = sub_graph_node->get_function()) { - auto sub_graph_precisions = find_all_used_precisions(sub_graph); + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { + size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs(); + for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) { + auto sub_graph_precisions = find_all_used_precisions(sub_graph_node->get_function(sub_graph_ind)); used_precisions.insert(sub_graph_precisions.begin(), sub_graph_precisions.end()); } } diff --git a/ngraph/core/src/pass/graph_rewrite.cpp b/ngraph/core/src/pass/graph_rewrite.cpp index 7124a5b3f5b..c88f010e265 100644 --- a/ngraph/core/src/pass/graph_rewrite.cpp +++ b/ngraph/core/src/pass/graph_rewrite.cpp @@ -172,8 +172,10 @@ bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr f, continue; // Recursive apply Matchers for sub-graph based nodes - if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { - if (auto sub_graph = sub_graph_node->get_function()) { + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { + size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs(); + for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) { + auto sub_graph = sub_graph_node->get_function(sub_graph_ind); run_on_function(sub_graph); } }