[CPU] If-8 operation implementation. (#7253)

This commit is contained in:
Nikolay Shchegolev 2021-10-13 20:45:24 +03:00 committed by GitHub
parent 7b1a418bf4
commit 0ed1c24cd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 579 additions and 16 deletions

View File

@ -144,6 +144,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
{ "Cosh", Math},
{ "Floor", Math},
{ "HardSigmoid", Math},
{ "If", If},
{ "Log", Math},
{ "Neg", Math},
{ "Reciprocal", Math},
@ -320,6 +321,8 @@ std::string NameFromType(const Type type) {
return "DetectionOutput";
case ExperimentalDetectronDetectionOutput:
return "ExperimentalDetectronDetectionOutput";
case If:
return "If";
case LogSoftmax:
return "LogSoftmax";
case TopK:

View File

@ -17,6 +17,7 @@ using VectorDims = std::vector<Dim>;
enum Type {
Unknown,
Generic,
If,
Reorder,
Input,
Output,

View File

@ -21,6 +21,7 @@
#include <nodes/mkldnn_matmul_node.h>
#include <nodes/mkldnn_fullyconnected_node.h>
#include <nodes/mkldnn_generic_node.h>
#include <nodes/mkldnn_if_node.h>
#include <nodes/mkldnn_input_node.h>
#include <nodes/mkldnn_lrn_node.h>
#include <nodes/mkldnn_pooling_node.h>
@ -1127,10 +1128,16 @@ MKLDNNNode* MKLDNNNode::NodesFactory::create(const std::shared_ptr<ngraph::Node>
// WA-start : TI node requires all attributes to construct internal subgpath
// including extManager, socket and mkldnn::eng.
MKLDNNTensorIteratorNode *ti = dynamic_cast<MKLDNNTensorIteratorNode*>(newNode);
if (ti != nullptr)
if (newNode) {
if (newNode->getType() == TensorIterator) {
if (auto ti = dynamic_cast<MKLDNNTensorIteratorNode*>(newNode))
ti->setExtManager(extMgr);
// WA-end
} else if (newNode->getType() == If) {
if (auto ifNode = dynamic_cast<MKLDNNIfNode*>(newNode))
ifNode->setExtManager(extMgr);
}
}
// // WA-end
if (!newNode) {
std::string errorDetails;

View File

@ -0,0 +1,230 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mkldnn_if_node.h"
#include <mkldnn_extension_utils.h>
#include <ie_ngraph_utils.hpp>
#include <map>
#include <string>
#include <vector>
using namespace MKLDNNPlugin;
MKLDNNIfNode::PortMapHelper::PortMapHelper(const MKLDNNMemoryPtr &from, const MKLDNNMemoryPtr &to, const mkldnn::engine& eng) {
mem_holder_src = from->GetPrimitive();
mem_holder_dst = to->GetPrimitive();
reorder = {mem_holder_src, mem_holder_dst};
}
void MKLDNNIfNode::PortMapHelper::execute(mkldnn::stream& strm) {
reorder.execute(strm, mem_holder_src, mem_holder_dst);
}
bool MKLDNNIfNode::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (isDynamicNgraphNode(op)) {
errorMessage = "If node doesn't support op with dynamic shapes";
return false;
}
if (!one_of(op->get_type_info(),
ov::op::v8::If::type_info)) {
errorMessage = "Not supported If operation version " + std::to_string(op->get_type_info().version) +
" with name '" + op->get_friendly_name() + "'. Node If supports only opset8 version.";
return false;
}
} catch (...) {
return false;
}
return true;
}
MKLDNNIfNode::MKLDNNIfNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) :
MKLDNNNode(op, eng, cache), ovOp(op) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage;
}
}
void MKLDNNIfNode::getSupportedDescriptors() {
auto ifOp = ov::as_type_ptr<ov::op::v8::If>(ovOp);
const std::shared_ptr<const ov::Function>& thenBody = ifOp->get_then_body();
const std::shared_ptr<const ov::Function>& elseBody = ifOp->get_else_body();
subGraphThen.CreateGraph(thenBody, ext_mng, weightCache);
subGraphElse.CreateGraph(elseBody, ext_mng, weightCache);
const auto &inMapThen = subGraphThen.GetInputNodesMap();
for (const auto &param : ifOp->get_then_body()->get_parameters()) {
auto inNode = inMapThen.find(param->get_friendly_name());
if (inNode != inMapThen.end()) {
auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr();
inputMemThen.push_back(inMem);
} else {
IE_THROW() << "Then body of node If with name " << getName() << " does not have input with name: "
<< param->get_friendly_name();
}
}
const auto &inMapElse = subGraphElse.GetInputNodesMap();
for (const auto &param : ifOp->get_else_body()->get_parameters()) {
auto inNode = inMapElse.find(param->get_friendly_name());
if (inNode != inMapElse.end()) {
auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr();
inputMemElse.push_back(inMem);
} else {
IE_THROW() << "Else body of node If with name " << getName() << " does not have input with name: "
<< param->get_friendly_name();
}
}
const auto &outMapThen = subGraphThen.GetOutputNodesMap();
for (const auto& out : ifOp->get_then_body()->get_results()) {
auto prev = out->get_input_node_shared_ptr(0);
std::string inputID = prev->get_friendly_name();
if (prev->get_output_size() > 1) {
inputID += "." + std::to_string(out->get_input_source_output(0).get_index());
}
auto outNode = outMapThen.find(inputID);
if (outNode != outMapThen.end()) {
auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr();
outputMemThen.push_back(outMem);
} else {
IE_THROW() << "Then body of node If with name " << getName() << " does not have output with name: "
<< inputID;
}
}
const auto &outMapElse = subGraphElse.GetOutputNodesMap();
for (const auto& out : ifOp->get_else_body()->get_results()) {
auto prev = out->get_input_node_shared_ptr(0);
std::string inputID = prev->get_friendly_name();
if (prev->get_output_size() > 1) {
inputID += "." + std::to_string(out->get_input_source_output(0).get_index());
}
auto outNode = outMapElse.find(inputID);
if (outNode != outMapElse.end()) {
auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr();
outputMemElse.push_back(outMem);
} else {
IE_THROW() << "Else body of node If with name " << getName() << " does not have output with name: "
<< inputID;
}
}
// Port map: outputs
for (const auto& desc : ifOp->get_output_descriptions(0)) {
auto body_output_idx = desc->m_body_value_index;
thenOutputPortMap.emplace_back(PortMap {
static_cast<int>(desc->m_output_index), static_cast<int>(body_output_idx)});
}
for (const auto& desc : ifOp->get_output_descriptions(1)) {
auto body_output_idx = desc->m_body_value_index;
elseOutputPortMap.emplace_back(PortMap {
static_cast<int>(desc->m_output_index), static_cast<int>(body_output_idx)});
}
for (const auto& desc : ifOp->get_input_descriptions(0)) {
auto body_input_index = desc->m_body_parameter_index;
thenInputPortMap.emplace_back(PortMap {
static_cast<int>(desc->m_input_index), static_cast<int>(body_input_index)});
}
for (const auto& desc : ifOp->get_input_descriptions(1)) {
auto body_input_index = desc->m_body_parameter_index;
elseInputPortMap.emplace_back(PortMap {
static_cast<int>(desc->m_input_index), static_cast<int>(body_input_index)});
}
}
void MKLDNNIfNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
NodeConfig config;
config.inConfs.reserve(getParentEdges().size());
config.outConfs.reserve(getChildEdges().size());
for (size_t i = 0; i < inputShapes.size(); i++) {
auto dims = inputShapes[i].getDims();
PortConfig dataConf {};
auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp);
dataConf.desc = descCreator->createSharedDesc(getOriginalInputPrecisionAtPort(i), Shape(dims));
config.inConfs.emplace_back(dataConf);
}
for (size_t i = 0; i < outputShapes.size(); i++) {
auto dims = outputShapes[i].getDims();
PortConfig dataConf {};
auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp);
dataConf.desc = descCreator->createSharedDesc(getOriginalOutputPrecisionAtPort(i), Shape(dims));
config.outConfs.push_back(dataConf);
}
config.dynBatchSupport = true;
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}
void MKLDNNIfNode::createPrimitive() {
const auto& eng = getEngine();
for (auto& map_rule : thenInputPortMap) {
auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
auto &toMem = inputMemThen[map_rule.to];
beforeThenMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
}
for (auto& map_rule : elseInputPortMap) {
auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
auto &toMem = inputMemElse[map_rule.to];
beforeElseMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
}
for (auto& map_rule : thenOutputPortMap) {
auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
auto &fromMem = outputMemThen[map_rule.to];
afterThenMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
}
for (auto& map_rule : elseOutputPortMap) {
auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
auto &fromMem = outputMemElse[map_rule.to];
afterElseMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
}
}
void MKLDNNIfNode::execute(mkldnn::stream strm) {
const bool condition = *(reinterpret_cast<const bool*>(getParentEdgeAt(0)->getMemoryPtr()->GetPtr()));
if (condition) {
for (auto &mapper : beforeThenMappers)
mapper->execute(strm);
subGraphThen.ResetInferCount();
subGraphThen.Infer();
for (auto &mapper : afterThenMappers)
mapper->execute(strm);
} else {
for (auto &mapper : beforeElseMappers)
mapper->execute(strm);
subGraphElse.ResetInferCount();
subGraphElse.Infer();
for (auto &mapper : afterElseMappers)
mapper->execute(strm);
}
}
bool MKLDNNIfNode::created() const {
return getType() == If;
}
REG_MKLDNN_PRIM_FOR(MKLDNNIfNode, If);

View File

@ -0,0 +1,66 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <mkldnn_node.h>
#include <mkldnn_graph.h>
#include <memory>
#include <string>
#include <vector>
namespace MKLDNNPlugin {
class MKLDNNIfNode : public MKLDNNNode {
public:
MKLDNNIfNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
void initSupportedPrimitiveDescriptors() override;
void getSupportedDescriptors() override;
void createPrimitive() override;
bool created() const override;
void execute(mkldnn::stream strm) override;
void inline setExtManager(const MKLDNNExtensionManager::Ptr& extMgr) { ext_mng = extMgr; }
struct PortMap {
int from; /**< Index of external/internal out data */
int to; /**< Index of external/internal in data */
};
class PortMapHelper {
public:
PortMapHelper(const MKLDNNMemoryPtr& from, const MKLDNNMemoryPtr& to, const mkldnn::engine& eng);
virtual ~PortMapHelper() = default;
virtual void execute(mkldnn::stream& strm);
protected:
mkldnn::reorder reorder;
mkldnn::memory mem_holder_src;
mkldnn::memory mem_holder_dst;
};
private:
MKLDNNExtensionManager::Ptr ext_mng;
MKLDNNGraph subGraphThen;
MKLDNNGraph subGraphElse;
std::deque<MKLDNNMemoryPtr> inputMemThen, inputMemElse, outputMemThen, outputMemElse;
std::vector<std::shared_ptr<PortMapHelper>>
beforeThenMappers,
beforeElseMappers,
afterThenMappers,
afterElseMappers;
std::vector<PortMap>
thenInputPortMap,
thenOutputPortMap,
elseInputPortMap,
elseOutputPortMap;
const std::shared_ptr<ov::Node> ovOp;
};
} // namespace MKLDNNPlugin

View File

@ -0,0 +1,57 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/subgraph/simple_if.hpp"
using namespace SubgraphTestsDefinitions;
namespace {
std::vector<std::vector<std::vector<size_t>>> inputShapes = {
{{5, 7}, {5, 7}},
{{30, 20, 10}, {30, 20, 10}}
};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::I8,
};
std::vector<bool> conditions = {true, false};
INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(conditions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
SimpleIfTest::getTestCaseName);
TEST_P(SimpleIfTest, CompareWithRefs) {
Run();
};
INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIf2OutTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(conditions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
SimpleIf2OutTest::getTestCaseName);
TEST_P(SimpleIf2OutTest, CompareWithRefs) {
Run();
};
INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfNotConstConditionTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(conditions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
SimpleIfNotConstConditionTest::getTestCaseName);
TEST_P(SimpleIfNotConstConditionTest, CompareWithRefs) {
Run();
};
} // namespace

View File

@ -0,0 +1,45 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include "shared_test_classes/base/layer_test_utils.hpp"
namespace SubgraphTestsDefinitions {
using SimpleIfParamsTuple = typename std::tuple<
std::vector<std::vector<size_t>>, // Input shapes
InferenceEngine::Precision, // Network precision
bool, // If condition
std::string // Device name
>;
class SimpleIfTest:
public testing::WithParamInterface<SimpleIfParamsTuple>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<SimpleIfParamsTuple> &obj);
protected:
void SetUp() override;
};
class SimpleIf2OutTest : public SimpleIfTest {
protected:
void SetUp() override;
};
class SimpleIfNotConstConditionTest : public SimpleIfTest {
public:
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override;
protected:
void SetUp() override;
bool condition;
};
} // namespace SubgraphTestsDefinitions

View File

@ -0,0 +1,142 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/subgraph/simple_if.hpp"
#include "ngraph_functions/builders.hpp"
namespace SubgraphTestsDefinitions {
std::string SimpleIfTest::getTestCaseName(const testing::TestParamInfo<SimpleIfParamsTuple> &obj) {
std::vector<std::vector<size_t>> inputShapes;
InferenceEngine::Precision netPrecision;
bool condition;
std::string targetName;
std::tie(inputShapes, netPrecision, condition, targetName) = obj.param;
std::ostringstream results;
results << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
results << "netPRC=" << netPrecision.name() << "_";
results << "Cond=" << condition << "_";
results << "targetDevice=" << targetName << "_";
return results.str();
}
void SimpleIfTest::SetUp() {
std::vector<std::vector<size_t>> inputShapes;
auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
bool condition;
std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
auto paramOuts = ngraph::helpers::convert2OutputVector(
ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));
auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
auto res2 = std::make_shared<ov::op::v0::Result>(p3);
auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1}, ov::ParameterVector{p1, p2});
auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res2}, ov::ParameterVector{p3});
auto condOp = ngraph::builder::makeConstant<bool>(ov::element::Type_t::boolean, {1}, {condition});
auto ifOp = std::make_shared<ov::op::v8::If>(condOp);
ifOp->set_then_body(thenBody);
ifOp->set_else_body(elseBody);
ifOp->set_input(paramOuts[0], p1, p3);
ifOp->set_input(paramOuts[1], p2, nullptr);
auto res = ifOp->set_output(res1, res2);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(res)};
function = std::make_shared<ov::Function>(results, params, "simpleIf");
}
void SimpleIf2OutTest::SetUp() {
std::vector<std::vector<size_t>> inputShapes;
auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
bool condition;
std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
auto paramOuts = ngraph::helpers::convert2OutputVector(
ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));
auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
auto p4 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
auto res2 = std::make_shared<ov::op::v0::Result>(thenOp);
auto res3 = std::make_shared<ov::op::v0::Result>(p3);
auto res4 = std::make_shared<ov::op::v0::Result>(p4);
auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2});
auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4});
auto condOp = ngraph::builder::makeConstant<bool>(ov::element::Type_t::boolean, {1}, {condition});
auto ifOp = std::make_shared<ov::op::v8::If>(condOp);
ifOp->set_then_body(thenBody);
ifOp->set_else_body(elseBody);
ifOp->set_input(paramOuts[0], p1, p3);
ifOp->set_input(paramOuts[1], p2, p4);
auto ifRes1 = ifOp->set_output(res1, res3);
auto ifRes2 = ifOp->set_output(res2, res4);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(ifRes1), std::make_shared<ov::op::v0::Result>(ifRes2)};
function = std::make_shared<ov::Function>(results, params, "simpleIf2Out");
}
void SimpleIfNotConstConditionTest::SetUp() {
std::vector<std::vector<size_t>> inputShapes;
auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
ov::ParameterVector params {
ngraph::builder::makeParams(ngPrc, {inputShapes[0]})[0],
ngraph::builder::makeParams(ngPrc, {inputShapes[1]})[0],
ngraph::builder::makeParams(ov::element::boolean, { {"condition", {1}} })[0]
};
auto paramOuts = ngraph::helpers::convert2OutputVector(
ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));
auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
auto p4 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
auto res2 = std::make_shared<ov::op::v0::Result>(thenOp);
auto res3 = std::make_shared<ov::op::v0::Result>(p3);
auto res4 = std::make_shared<ov::op::v0::Result>(p4);
auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2});
auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4});
auto ifOp = std::make_shared<ov::op::v8::If>(paramOuts[2]);
ifOp->set_then_body(thenBody);
ifOp->set_else_body(elseBody);
ifOp->set_input(paramOuts[0], p1, p3);
ifOp->set_input(paramOuts[1], p2, p4);
auto ifRes1 = ifOp->set_output(res1, res3);
auto ifRes2 = ifOp->set_output(res2, res4);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(ifRes1), std::make_shared<ov::op::v0::Result>(ifRes2)};
function = std::make_shared<ov::Function>(results, params, "SimpleIfNotConstConditionTest");
}
InferenceEngine::Blob::Ptr SimpleIfNotConstConditionTest::GenerateInput(const InferenceEngine::InputInfo& info) const {
if (info.name() == "condition") {
bool conditionArr[1] = { condition };
return FuncTestUtils::createAndFillBlobWithFloatArray(info.getTensorDesc(), conditionArr, 1);
} else {
return FuncTestUtils::createAndFillBlob(info.getTensorDesc());
}
}
} // namespace SubgraphTestsDefinitions

View File

@ -83,7 +83,7 @@ template<typename fromType, typename toType>
std::vector<toType> castVector(const std::vector<fromType> &vec) {
std::vector<toType> resVec;
resVec.reserve(vec.size());
for (auto &el : vec) {
for (const auto &el : vec) {
resVec.push_back(static_cast<toType>(el));
}
return resVec;

View File

@ -273,6 +273,13 @@ public:
/// \param bodies_results vector of bodies results for one output.
/// \return value Output node for bodies_results.
virtual Output<Node> set_body_outputs(const ResultVector& bodies_results);
///
/// \brief Get number of internal sub-graphs
///
/// \return Number of sub-graphs.
virtual size_t get_num_internal_subgraphs() const {
return m_bodies.size();
}
MultiSubGraphOp(const MultiSubGraphOp&) = delete;
MultiSubGraphOp(MultiSubGraphOp&&) = default;

View File

@ -45,9 +45,10 @@ bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr<ov::Function> f)
}
} else {
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
if (const auto& sub_graph = sub_graph_node->get_function()) {
rewritten |= run_on_function(sub_graph);
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
rewritten |= run_on_function(sub_graph_node->get_function(sub_graph_ind));
}
}
}

View File

@ -161,9 +161,10 @@ bool convert_precision(pass::PassBase& pass,
for (auto& node : ops) {
pass.transformation_callback(node);
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
is_changed |= convert_function_precision(sub_graph, true);
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
is_changed |= convert_function_precision(sub_graph_node->get_function(sub_graph_ind), true);
}
}
is_changed |= convert_node_input_precision(node);
@ -226,9 +227,10 @@ precisions_set_t find_all_used_precisions(const std::shared_ptr<ngraph::Function
for (const auto& output : node->outputs()) {
used_precisions.emplace(output.get_element_type());
}
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
auto sub_graph_precisions = find_all_used_precisions(sub_graph);
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
auto sub_graph_precisions = find_all_used_precisions(sub_graph_node->get_function(sub_graph_ind));
used_precisions.insert(sub_graph_precisions.begin(), sub_graph_precisions.end());
}
}

View File

@ -172,8 +172,10 @@ bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
continue;
// Recursive apply Matchers for sub-graph based nodes
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
auto sub_graph = sub_graph_node->get_function(sub_graph_ind);
run_on_function(sub_graph);
}
}