[CPU] If-8 operation implementation. (#7253)
Parent: 7b1a418bf4
Commit: 0ed1c24cd2
@@ -144,6 +144,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
    { "Cosh", Math},
    { "Floor", Math},
    { "HardSigmoid", Math},
+   { "If", If},
    { "Log", Math},
    { "Neg", Math},
    { "Reciprocal", Math},
@@ -320,6 +321,8 @@ std::string NameFromType(const Type type) {
        return "DetectionOutput";
    case ExperimentalDetectronDetectionOutput:
        return "ExperimentalDetectronDetectionOutput";
+   case If:
+       return "If";
    case LogSoftmax:
        return "LogSoftmax";
    case TopK:
@@ -17,6 +17,7 @@ using VectorDims = std::vector<Dim>;
enum Type {
    Unknown,
    Generic,
+   If,
    Reorder,
    Input,
    Output,
@@ -21,6 +21,7 @@
#include <nodes/mkldnn_matmul_node.h>
#include <nodes/mkldnn_fullyconnected_node.h>
#include <nodes/mkldnn_generic_node.h>
+#include <nodes/mkldnn_if_node.h>
#include <nodes/mkldnn_input_node.h>
#include <nodes/mkldnn_lrn_node.h>
#include <nodes/mkldnn_pooling_node.h>
@@ -1127,10 +1128,16 @@ MKLDNNNode* MKLDNNNode::NodesFactory::create(const std::shared_ptr<ngraph::Node>

    // WA-start : TI node requires all attributes to construct internal subgraph
    // including extManager, socket and mkldnn::eng.
-   MKLDNNTensorIteratorNode *ti = dynamic_cast<MKLDNNTensorIteratorNode*>(newNode);
-   if (ti != nullptr)
-       ti->setExtManager(extMgr);
-   // WA-end
+   if (newNode) {
+       if (newNode->getType() == TensorIterator) {
+           if (auto ti = dynamic_cast<MKLDNNTensorIteratorNode*>(newNode))
+               ti->setExtManager(extMgr);
+       } else if (newNode->getType() == If) {
+           if (auto ifNode = dynamic_cast<MKLDNNIfNode*>(newNode))
+               ifNode->setExtManager(extMgr);
+       }
+   }
+   // WA-end

    if (!newNode) {
        std::string errorDetails;
inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.cpp (new file, 230 lines)
@@ -0,0 +1,230 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mkldnn_if_node.h"

#include <mkldnn_extension_utils.h>
#include <ie_ngraph_utils.hpp>

#include <map>
#include <string>
#include <vector>

using namespace MKLDNNPlugin;

MKLDNNIfNode::PortMapHelper::PortMapHelper(const MKLDNNMemoryPtr &from, const MKLDNNMemoryPtr &to, const mkldnn::engine& eng) {
    mem_holder_src = from->GetPrimitive();
    mem_holder_dst = to->GetPrimitive();
    reorder = {mem_holder_src, mem_holder_dst};
}

void MKLDNNIfNode::PortMapHelper::execute(mkldnn::stream& strm) {
    reorder.execute(strm, mem_holder_src, mem_holder_dst);
}
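PortMapHelper is a thin wrapper over a oneDNN reorder primitive: it caches the source and destination memory objects once and re-runs the copy (with layout conversion if needed) on every call. A minimal standalone sketch of the same mechanism, written against the stock oneDNN C++ API (modern dnnl namespace; the plugin uses the older mkldnn spelling); the shapes and layouts are illustrative assumptions:

#include <dnnl.hpp>

int main() {
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    dnnl::stream strm(eng);

    // Same logical shape, different physical layouts.
    dnnl::memory::desc src_md({2, 3, 4, 4}, dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::nchw);
    dnnl::memory::desc dst_md({2, 3, 4, 4}, dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::nhwc);
    dnnl::memory src(src_md, eng), dst(dst_md, eng);

    // Construct once, execute per inference: the same pattern PortMapHelper
    // uses with its cached mem_holder_src/mem_holder_dst members.
    dnnl::reorder reorder(src, dst);
    reorder.execute(strm, src, dst);
    strm.wait();
    return 0;
}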

bool MKLDNNIfNode::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    try {
        if (isDynamicNgraphNode(op)) {
            errorMessage = "If node doesn't support op with dynamic shapes";
            return false;
        }
        if (!one_of(op->get_type_info(),
                ov::op::v8::If::type_info)) {
            errorMessage = "Not supported If operation version " + std::to_string(op->get_type_info().version) +
                " with name '" + op->get_friendly_name() + "'. Node If supports only opset8 version.";
            return false;
        }
    } catch (...) {
        return false;
    }
    return true;
}

MKLDNNIfNode::MKLDNNIfNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) :
        MKLDNNNode(op, eng, cache), ovOp(op) {
    std::string errorMessage;
    if (!isSupportedOperation(op, errorMessage)) {
        IE_THROW(NotImplemented) << errorMessage;
    }
}

void MKLDNNIfNode::getSupportedDescriptors() {
    auto ifOp = ov::as_type_ptr<ov::op::v8::If>(ovOp);

    const std::shared_ptr<const ov::Function>& thenBody = ifOp->get_then_body();
    const std::shared_ptr<const ov::Function>& elseBody = ifOp->get_else_body();
    subGraphThen.CreateGraph(thenBody, ext_mng, weightCache);
    subGraphElse.CreateGraph(elseBody, ext_mng, weightCache);

    const auto &inMapThen = subGraphThen.GetInputNodesMap();
    for (const auto &param : ifOp->get_then_body()->get_parameters()) {
        auto inNode = inMapThen.find(param->get_friendly_name());
        if (inNode != inMapThen.end()) {
            auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr();
            inputMemThen.push_back(inMem);
        } else {
            IE_THROW() << "Then body of node If with name " << getName() << " does not have input with name: "
                    << param->get_friendly_name();
        }
    }

    const auto &inMapElse = subGraphElse.GetInputNodesMap();
    for (const auto &param : ifOp->get_else_body()->get_parameters()) {
        auto inNode = inMapElse.find(param->get_friendly_name());
        if (inNode != inMapElse.end()) {
            auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr();
            inputMemElse.push_back(inMem);
        } else {
            IE_THROW() << "Else body of node If with name " << getName() << " does not have input with name: "
                    << param->get_friendly_name();
        }
    }

    const auto &outMapThen = subGraphThen.GetOutputNodesMap();
    for (const auto& out : ifOp->get_then_body()->get_results()) {
        auto prev = out->get_input_node_shared_ptr(0);
        std::string inputID = prev->get_friendly_name();
        if (prev->get_output_size() > 1) {
            inputID += "." + std::to_string(out->get_input_source_output(0).get_index());
        }
        auto outNode = outMapThen.find(inputID);
        if (outNode != outMapThen.end()) {
            auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr();
            outputMemThen.push_back(outMem);
        } else {
            IE_THROW() << "Then body of node If with name " << getName() << " does not have output with name: "
                    << inputID;
        }
    }

    const auto &outMapElse = subGraphElse.GetOutputNodesMap();
    for (const auto& out : ifOp->get_else_body()->get_results()) {
        auto prev = out->get_input_node_shared_ptr(0);
        std::string inputID = prev->get_friendly_name();
        if (prev->get_output_size() > 1) {
            inputID += "." + std::to_string(out->get_input_source_output(0).get_index());
        }
        auto outNode = outMapElse.find(inputID);
        if (outNode != outMapElse.end()) {
            auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr();
            outputMemElse.push_back(outMem);
        } else {
            IE_THROW() << "Else body of node If with name " << getName() << " does not have output with name: "
                    << inputID;
        }
    }

    // Port map: outputs
    for (const auto& desc : ifOp->get_output_descriptions(0)) {
        auto body_output_idx = desc->m_body_value_index;
        thenOutputPortMap.emplace_back(PortMap {
            static_cast<int>(desc->m_output_index), static_cast<int>(body_output_idx)});
    }
    for (const auto& desc : ifOp->get_output_descriptions(1)) {
        auto body_output_idx = desc->m_body_value_index;
        elseOutputPortMap.emplace_back(PortMap {
            static_cast<int>(desc->m_output_index), static_cast<int>(body_output_idx)});
    }

    for (const auto& desc : ifOp->get_input_descriptions(0)) {
        auto body_input_index = desc->m_body_parameter_index;
        thenInputPortMap.emplace_back(PortMap {
            static_cast<int>(desc->m_input_index), static_cast<int>(body_input_index)});
    }
    for (const auto& desc : ifOp->get_input_descriptions(1)) {
        auto body_input_index = desc->m_body_parameter_index;
        elseInputPortMap.emplace_back(PortMap {
            static_cast<int>(desc->m_input_index), static_cast<int>(body_input_index)});
    }
}
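getSupportedDescriptors resolves two kinds of wiring: body Parameters and Results are matched to subgraph memories by friendly name, and If-level ports are mapped to body indices via the op's input/output descriptions. A hedged sketch of how those descriptions can be inspected on any ov::op::v8::If (the header path is an assumption; the member names come from the hunk above):

#include <iostream>
#include <memory>
#include <ngraph/op/if.hpp>  // assumed header location for ov::op::v8::If

// Dump the port maps of an If op, mirroring what getSupportedDescriptors records.
void dumpPortMaps(const std::shared_ptr<ov::op::v8::If>& ifOp) {
    for (int body = 0; body < 2; ++body) {  // 0 = then body, 1 = else body
        for (const auto& in : ifOp->get_input_descriptions(body))
            std::cout << "body " << body << ": If input " << in->m_input_index
                      << " -> body parameter " << in->m_body_parameter_index << '\n';
        for (const auto& out : ifOp->get_output_descriptions(body))
            std::cout << "body " << body << ": body result " << out->m_body_value_index
                      << " -> If output " << out->m_output_index << '\n';
    }
}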

void MKLDNNIfNode::initSupportedPrimitiveDescriptors() {
    if (!supportedPrimitiveDescriptors.empty())
        return;

    NodeConfig config;
    config.inConfs.reserve(getParentEdges().size());
    config.outConfs.reserve(getChildEdges().size());

    for (size_t i = 0; i < inputShapes.size(); i++) {
        auto dims = inputShapes[i].getDims();

        PortConfig dataConf {};
        auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp);
        dataConf.desc = descCreator->createSharedDesc(getOriginalInputPrecisionAtPort(i), Shape(dims));
        config.inConfs.emplace_back(dataConf);
    }

    for (size_t i = 0; i < outputShapes.size(); i++) {
        auto dims = outputShapes[i].getDims();

        PortConfig dataConf {};
        auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp);
        dataConf.desc = descCreator->createSharedDesc(getOriginalOutputPrecisionAtPort(i), Shape(dims));
        config.outConfs.push_back(dataConf);
    }

    config.dynBatchSupport = true;

    supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}


void MKLDNNIfNode::createPrimitive() {
    const auto& eng = getEngine();

    for (auto& map_rule : thenInputPortMap) {
        auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
        auto &toMem = inputMemThen[map_rule.to];

        beforeThenMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
    }

    for (auto& map_rule : elseInputPortMap) {
        auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
        auto &toMem = inputMemElse[map_rule.to];

        beforeElseMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
    }

    for (auto& map_rule : thenOutputPortMap) {
        auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
        auto &fromMem = outputMemThen[map_rule.to];

        afterThenMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
    }

    for (auto& map_rule : elseOutputPortMap) {
        auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
        auto &fromMem = outputMemElse[map_rule.to];

        afterElseMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
    }
}

void MKLDNNIfNode::execute(mkldnn::stream strm) {
    const bool condition = *(reinterpret_cast<const bool*>(getParentEdgeAt(0)->getMemoryPtr()->GetPtr()));

    if (condition) {
        for (auto &mapper : beforeThenMappers)
            mapper->execute(strm);
        subGraphThen.ResetInferCount();
        subGraphThen.Infer();
        for (auto &mapper : afterThenMappers)
            mapper->execute(strm);
    } else {
        for (auto &mapper : beforeElseMappers)
            mapper->execute(strm);
        subGraphElse.ResetInferCount();
        subGraphElse.Infer();
        for (auto &mapper : afterElseMappers)
            mapper->execute(strm);
    }
}

bool MKLDNNIfNode::created() const {
    return getType() == If;
}

REG_MKLDNN_PRIM_FOR(MKLDNNIfNode, If);
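Taken together, the node reads the boolean condition from input 0, copies the mapped inputs into the selected body's subgraph, infers only that body, and copies its results back out. A hedged end-to-end sketch of exercising this path through the 2021-era Inference Engine API (op and class names as used in this commit; the build/run boilerplate is an assumption):

#include <inference_engine.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset8.hpp>

int main() {
    using namespace ngraph;
    auto x = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
    auto y = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});

    // then body: x + y
    auto p1 = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
    auto p2 = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
    auto thenBody = std::make_shared<Function>(OutputVector{std::make_shared<opset8::Add>(p1, p2)},
                                               ParameterVector{p1, p2});
    // else body: x * y
    auto p3 = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
    auto p4 = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
    auto elseBody = std::make_shared<Function>(OutputVector{std::make_shared<opset8::Multiply>(p3, p4)},
                                               ParameterVector{p3, p4});

    auto cond = opset8::Constant::create(element::boolean, Shape{1}, {true});
    auto ifOp = std::make_shared<opset8::If>(cond);
    ifOp->set_then_body(thenBody);
    ifOp->set_else_body(elseBody);
    ifOp->set_input(x, p1, p3);
    ifOp->set_input(y, p2, p4);
    auto out = ifOp->set_output(thenBody->get_results()[0], elseBody->get_results()[0]);

    auto f = std::make_shared<Function>(OutputVector{out}, ParameterVector{x, y}, "ifDemo");

    InferenceEngine::Core core;
    auto exec = core.LoadNetwork(InferenceEngine::CNNNetwork(f), "CPU");
    exec.CreateInferRequest().Infer();
    return 0;
}

Note that a constant condition may let the If be folded away before it reaches the plugin; SimpleIfNotConstConditionTest below covers the case where the condition is a runtime input.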
inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.h (new file, 66 lines)
@@ -0,0 +1,66 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <mkldnn_node.h>
#include <mkldnn_graph.h>

#include <deque>
#include <memory>
#include <string>
#include <vector>

namespace MKLDNNPlugin {

class MKLDNNIfNode : public MKLDNNNode {
public:
    MKLDNNIfNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);

    static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
    void initSupportedPrimitiveDescriptors() override;
    void getSupportedDescriptors() override;
    void createPrimitive() override;
    bool created() const override;
    void execute(mkldnn::stream strm) override;

    void inline setExtManager(const MKLDNNExtensionManager::Ptr& extMgr) { ext_mng = extMgr; }

    struct PortMap {
        int from; /**< Index of external/internal out data */
        int to;   /**< Index of external/internal in data */
    };

    class PortMapHelper {
    public:
        PortMapHelper(const MKLDNNMemoryPtr& from, const MKLDNNMemoryPtr& to, const mkldnn::engine& eng);
        virtual ~PortMapHelper() = default;
        virtual void execute(mkldnn::stream& strm);
    protected:
        mkldnn::reorder reorder;
        mkldnn::memory mem_holder_src;
        mkldnn::memory mem_holder_dst;
    };

private:
    MKLDNNExtensionManager::Ptr ext_mng;
    MKLDNNGraph subGraphThen;
    MKLDNNGraph subGraphElse;
    std::deque<MKLDNNMemoryPtr> inputMemThen, inputMemElse, outputMemThen, outputMemElse;

    std::vector<std::shared_ptr<PortMapHelper>>
        beforeThenMappers,
        beforeElseMappers,
        afterThenMappers,
        afterElseMappers;

    std::vector<PortMap>
        thenInputPortMap,
        thenOutputPortMap,
        elseInputPortMap,
        elseOutputPortMap;

    const std::shared_ptr<ov::Node> ovOp;
};

}  // namespace MKLDNNPlugin
@@ -0,0 +1,57 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "shared_test_classes/subgraph/simple_if.hpp"

using namespace SubgraphTestsDefinitions;

namespace {
std::vector<std::vector<std::vector<size_t>>> inputShapes = {
    {{5, 7}, {5, 7}},
    {{30, 20, 10}, {30, 20, 10}}
};

std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
                                                         InferenceEngine::Precision::I8,
};

std::vector<bool> conditions = {true, false};

INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfTest,
                ::testing::Combine(
                        ::testing::ValuesIn(inputShapes),
                        ::testing::ValuesIn(netPrecisions),
                        ::testing::ValuesIn(conditions),
                        ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                SimpleIfTest::getTestCaseName);

TEST_P(SimpleIfTest, CompareWithRefs) {
    Run();
};

INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIf2OutTest,
                ::testing::Combine(
                        ::testing::ValuesIn(inputShapes),
                        ::testing::ValuesIn(netPrecisions),
                        ::testing::ValuesIn(conditions),
                        ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                SimpleIf2OutTest::getTestCaseName);

TEST_P(SimpleIf2OutTest, CompareWithRefs) {
    Run();
};

INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfNotConstConditionTest,
                ::testing::Combine(
                        ::testing::ValuesIn(inputShapes),
                        ::testing::ValuesIn(netPrecisions),
                        ::testing::ValuesIn(conditions),
                        ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                SimpleIfNotConstConditionTest::getTestCaseName);

TEST_P(SimpleIfNotConstConditionTest, CompareWithRefs) {
    Run();
};

}  // namespace
@@ -0,0 +1,45 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tuple>
#include <string>
#include <vector>
#include "shared_test_classes/base/layer_test_utils.hpp"

namespace SubgraphTestsDefinitions {

using SimpleIfParamsTuple = typename std::tuple<
        std::vector<std::vector<size_t>>,  // Input shapes
        InferenceEngine::Precision,        // Network precision
        bool,                              // If condition
        std::string                        // Device name
>;

class SimpleIfTest:
        public testing::WithParamInterface<SimpleIfParamsTuple>,
        public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<SimpleIfParamsTuple> &obj);
protected:
    void SetUp() override;
};

class SimpleIf2OutTest : public SimpleIfTest {
protected:
    void SetUp() override;
};

class SimpleIfNotConstConditionTest : public SimpleIfTest {
public:
    InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override;

protected:
    void SetUp() override;

    bool condition;
};

}  // namespace SubgraphTestsDefinitions
@@ -0,0 +1,142 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "shared_test_classes/subgraph/simple_if.hpp"
#include "ngraph_functions/builders.hpp"

namespace SubgraphTestsDefinitions {
std::string SimpleIfTest::getTestCaseName(const testing::TestParamInfo<SimpleIfParamsTuple> &obj) {
    std::vector<std::vector<size_t>> inputShapes;
    InferenceEngine::Precision netPrecision;
    bool condition;
    std::string targetName;
    std::tie(inputShapes, netPrecision, condition, targetName) = obj.param;
    std::ostringstream results;

    results << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
    results << "netPRC=" << netPrecision.name() << "_";
    results << "Cond=" << condition << "_";
    results << "targetDevice=" << targetName << "_";
    return results.str();
}

void SimpleIfTest::SetUp() {
    std::vector<std::vector<size_t>> inputShapes;
    auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
    bool condition;
    std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
    auto paramOuts = ngraph::helpers::convert2OutputVector(
            ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));

    auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
    auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
    auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));

    auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
    auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
    auto res2 = std::make_shared<ov::op::v0::Result>(p3);

    auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1}, ov::ParameterVector{p1, p2});
    auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res2}, ov::ParameterVector{p3});

    auto condOp = ngraph::builder::makeConstant<bool>(ov::element::Type_t::boolean, {1}, {condition});
    auto ifOp = std::make_shared<ov::op::v8::If>(condOp);
    ifOp->set_then_body(thenBody);
    ifOp->set_else_body(elseBody);
    ifOp->set_input(paramOuts[0], p1, p3);
    ifOp->set_input(paramOuts[1], p2, nullptr);
    auto res = ifOp->set_output(res1, res2);

    ov::ResultVector results{std::make_shared<ov::op::v0::Result>(res)};
    function = std::make_shared<ov::Function>(results, params, "simpleIf");
}

void SimpleIf2OutTest::SetUp() {
    std::vector<std::vector<size_t>> inputShapes;
    auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
    bool condition;
    std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
    auto paramOuts = ngraph::helpers::convert2OutputVector(
            ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));

    auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
    auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
    auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
    auto p4 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));

    auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
    auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
    auto res2 = std::make_shared<ov::op::v0::Result>(thenOp);
    auto res3 = std::make_shared<ov::op::v0::Result>(p3);
    auto res4 = std::make_shared<ov::op::v0::Result>(p4);

    auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2});
    auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4});

    auto condOp = ngraph::builder::makeConstant<bool>(ov::element::Type_t::boolean, {1}, {condition});
    auto ifOp = std::make_shared<ov::op::v8::If>(condOp);
    ifOp->set_then_body(thenBody);
    ifOp->set_else_body(elseBody);
    ifOp->set_input(paramOuts[0], p1, p3);
    ifOp->set_input(paramOuts[1], p2, p4);
    auto ifRes1 = ifOp->set_output(res1, res3);
    auto ifRes2 = ifOp->set_output(res2, res4);

    ov::ResultVector results{std::make_shared<ov::op::v0::Result>(ifRes1), std::make_shared<ov::op::v0::Result>(ifRes2)};
    function = std::make_shared<ov::Function>(results, params, "simpleIf2Out");
}

void SimpleIfNotConstConditionTest::SetUp() {
    std::vector<std::vector<size_t>> inputShapes;
    auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
    std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    ov::ParameterVector params {
        ngraph::builder::makeParams(ngPrc, {inputShapes[0]})[0],
        ngraph::builder::makeParams(ngPrc, {inputShapes[1]})[0],
        ngraph::builder::makeParams(ov::element::boolean, { {"condition", {1}} })[0]
    };
    auto paramOuts = ngraph::helpers::convert2OutputVector(
            ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));

    auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
    auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
    auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
    auto p4 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));

    auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
    auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
    auto res2 = std::make_shared<ov::op::v0::Result>(thenOp);
    auto res3 = std::make_shared<ov::op::v0::Result>(p3);
    auto res4 = std::make_shared<ov::op::v0::Result>(p4);

    auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2});
    auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4});

    auto ifOp = std::make_shared<ov::op::v8::If>(paramOuts[2]);
    ifOp->set_then_body(thenBody);
    ifOp->set_else_body(elseBody);
    ifOp->set_input(paramOuts[0], p1, p3);
    ifOp->set_input(paramOuts[1], p2, p4);
    auto ifRes1 = ifOp->set_output(res1, res3);
    auto ifRes2 = ifOp->set_output(res2, res4);

    ov::ResultVector results{std::make_shared<ov::op::v0::Result>(ifRes1), std::make_shared<ov::op::v0::Result>(ifRes2)};
    function = std::make_shared<ov::Function>(results, params, "SimpleIfNotConstConditionTest");
}

InferenceEngine::Blob::Ptr SimpleIfNotConstConditionTest::GenerateInput(const InferenceEngine::InputInfo& info) const {
    if (info.name() == "condition") {
        bool conditionArr[1] = { condition };
        return FuncTestUtils::createAndFillBlobWithFloatArray(info.getTensorDesc(), conditionArr, 1);
    } else {
        return FuncTestUtils::createAndFillBlob(info.getTensorDesc());
    }
}

}  // namespace SubgraphTestsDefinitions
@@ -83,7 +83,7 @@ template<typename fromType, typename toType>
std::vector<toType> castVector(const std::vector<fromType> &vec) {
    std::vector<toType> resVec;
    resVec.reserve(vec.size());
-   for (auto &el : vec) {
+   for (const auto &el : vec) {
        resVec.push_back(static_cast<toType>(el));
    }
    return resVec;
@@ -273,6 +273,13 @@ public:
    /// \param bodies_results vector of bodies results for one output.
    /// \return value Output node for bodies_results.
    virtual Output<Node> set_body_outputs(const ResultVector& bodies_results);
+   ///
+   /// \brief Get number of internal sub-graphs
+   ///
+   /// \return Number of sub-graphs.
+   virtual size_t get_num_internal_subgraphs() const {
+       return m_bodies.size();
+   }

    MultiSubGraphOp(const MultiSubGraphOp&) = delete;
    MultiSubGraphOp(MultiSubGraphOp&&) = default;
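get_num_internal_subgraphs lets passes iterate every body of a multi-body op instead of assuming exactly one, which is what the four hunks below do for constant folding, precision conversion, precision discovery, and GraphRewrite. A condensed sketch of the shared pattern (the free function and header path are illustrative assumptions, not part of the diff):

#include <memory>
#include <ngraph/op/util/multi_subgraph_base.hpp>  // assumed header for MultiSubGraphOp

// Apply `process` to every internal body of a node, if it has any.
// TensorIterator and Loop expose one body; If exposes two (then, else).
template <typename Fn>
void for_each_body(const std::shared_ptr<ngraph::Node>& node, Fn&& process) {
    if (auto multi = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
        size_t n = multi->get_num_internal_subgraphs();
        for (size_t i = 0; i < n; ++i)
            process(multi->get_function(static_cast<int>(i)));
    }
}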
@@ -45,9 +45,10 @@ bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr<ov::Function> f)
        }
    } else {
        // recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
-       if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
-           if (const auto& sub_graph = sub_graph_node->get_function()) {
-               rewritten |= run_on_function(sub_graph);
+       if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
+           size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
+           for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
+               rewritten |= run_on_function(sub_graph_node->get_function(sub_graph_ind));
            }
        }
    }
@@ -161,9 +161,10 @@ bool convert_precision(pass::PassBase& pass,
    for (auto& node : ops) {
        pass.transformation_callback(node);
        // Recursively apply transformation for sub-graph based operations
-       if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
-           if (auto sub_graph = sub_graph_node->get_function()) {
-               is_changed |= convert_function_precision(sub_graph, true);
+       if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
+           size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
+           for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
+               is_changed |= convert_function_precision(sub_graph_node->get_function(sub_graph_ind), true);
            }
        }
        is_changed |= convert_node_input_precision(node);
@@ -226,9 +227,10 @@ precisions_set_t find_all_used_precisions(const std::shared_ptr<ngraph::Function
    for (const auto& output : node->outputs()) {
        used_precisions.emplace(output.get_element_type());
    }
-   if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
-       if (auto sub_graph = sub_graph_node->get_function()) {
-           auto sub_graph_precisions = find_all_used_precisions(sub_graph);
+   if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
+       size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
+       for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
+           auto sub_graph_precisions = find_all_used_precisions(sub_graph_node->get_function(sub_graph_ind));
            used_precisions.insert(sub_graph_precisions.begin(), sub_graph_precisions.end());
        }
    }
@@ -172,8 +172,10 @@ bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
            continue;

        // Recursive apply Matchers for sub-graph based nodes
-       if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
-           if (auto sub_graph = sub_graph_node->get_function()) {
+       if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
+           size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
+           for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
+               auto sub_graph = sub_graph_node->get_function(sub_graph_ind);
                run_on_function(sub_graph);
            }
        }