[CPU] If-8 operation implementation. (#7253)
This commit is contained in:
parent
7b1a418bf4
commit
0ed1c24cd2
@ -144,6 +144,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
|
|||||||
{ "Cosh", Math},
|
{ "Cosh", Math},
|
||||||
{ "Floor", Math},
|
{ "Floor", Math},
|
||||||
{ "HardSigmoid", Math},
|
{ "HardSigmoid", Math},
|
||||||
|
{ "If", If},
|
||||||
{ "Log", Math},
|
{ "Log", Math},
|
||||||
{ "Neg", Math},
|
{ "Neg", Math},
|
||||||
{ "Reciprocal", Math},
|
{ "Reciprocal", Math},
|
||||||
@ -320,6 +321,8 @@ std::string NameFromType(const Type type) {
|
|||||||
return "DetectionOutput";
|
return "DetectionOutput";
|
||||||
case ExperimentalDetectronDetectionOutput:
|
case ExperimentalDetectronDetectionOutput:
|
||||||
return "ExperimentalDetectronDetectionOutput";
|
return "ExperimentalDetectronDetectionOutput";
|
||||||
|
case If:
|
||||||
|
return "If";
|
||||||
case LogSoftmax:
|
case LogSoftmax:
|
||||||
return "LogSoftmax";
|
return "LogSoftmax";
|
||||||
case TopK:
|
case TopK:
|
||||||
|
@ -17,6 +17,7 @@ using VectorDims = std::vector<Dim>;
|
|||||||
enum Type {
|
enum Type {
|
||||||
Unknown,
|
Unknown,
|
||||||
Generic,
|
Generic,
|
||||||
|
If,
|
||||||
Reorder,
|
Reorder,
|
||||||
Input,
|
Input,
|
||||||
Output,
|
Output,
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#include <nodes/mkldnn_matmul_node.h>
|
#include <nodes/mkldnn_matmul_node.h>
|
||||||
#include <nodes/mkldnn_fullyconnected_node.h>
|
#include <nodes/mkldnn_fullyconnected_node.h>
|
||||||
#include <nodes/mkldnn_generic_node.h>
|
#include <nodes/mkldnn_generic_node.h>
|
||||||
|
#include <nodes/mkldnn_if_node.h>
|
||||||
#include <nodes/mkldnn_input_node.h>
|
#include <nodes/mkldnn_input_node.h>
|
||||||
#include <nodes/mkldnn_lrn_node.h>
|
#include <nodes/mkldnn_lrn_node.h>
|
||||||
#include <nodes/mkldnn_pooling_node.h>
|
#include <nodes/mkldnn_pooling_node.h>
|
||||||
@ -1127,10 +1128,16 @@ MKLDNNNode* MKLDNNNode::NodesFactory::create(const std::shared_ptr<ngraph::Node>
|
|||||||
|
|
||||||
// WA-start : TI node requires all attributes to construct internal subgpath
|
// WA-start : TI node requires all attributes to construct internal subgpath
|
||||||
// including extManager, socket and mkldnn::eng.
|
// including extManager, socket and mkldnn::eng.
|
||||||
MKLDNNTensorIteratorNode *ti = dynamic_cast<MKLDNNTensorIteratorNode*>(newNode);
|
if (newNode) {
|
||||||
if (ti != nullptr)
|
if (newNode->getType() == TensorIterator) {
|
||||||
ti->setExtManager(extMgr);
|
if (auto ti = dynamic_cast<MKLDNNTensorIteratorNode*>(newNode))
|
||||||
// WA-end
|
ti->setExtManager(extMgr);
|
||||||
|
} else if (newNode->getType() == If) {
|
||||||
|
if (auto ifNode = dynamic_cast<MKLDNNIfNode*>(newNode))
|
||||||
|
ifNode->setExtManager(extMgr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// // WA-end
|
||||||
|
|
||||||
if (!newNode) {
|
if (!newNode) {
|
||||||
std::string errorDetails;
|
std::string errorDetails;
|
||||||
|
230
inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.cpp
Normal file
230
inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.cpp
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "mkldnn_if_node.h"
|
||||||
|
|
||||||
|
#include <mkldnn_extension_utils.h>
|
||||||
|
#include <ie_ngraph_utils.hpp>
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using namespace MKLDNNPlugin;
|
||||||
|
|
||||||
|
MKLDNNIfNode::PortMapHelper::PortMapHelper(const MKLDNNMemoryPtr &from, const MKLDNNMemoryPtr &to, const mkldnn::engine& eng) {
|
||||||
|
mem_holder_src = from->GetPrimitive();
|
||||||
|
mem_holder_dst = to->GetPrimitive();
|
||||||
|
reorder = {mem_holder_src, mem_holder_dst};
|
||||||
|
}
|
||||||
|
|
||||||
|
void MKLDNNIfNode::PortMapHelper::execute(mkldnn::stream& strm) {
|
||||||
|
reorder.execute(strm, mem_holder_src, mem_holder_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MKLDNNIfNode::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||||
|
try {
|
||||||
|
if (isDynamicNgraphNode(op)) {
|
||||||
|
errorMessage = "If node doesn't support op with dynamic shapes";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!one_of(op->get_type_info(),
|
||||||
|
ov::op::v8::If::type_info)) {
|
||||||
|
errorMessage = "Not supported If operation version " + std::to_string(op->get_type_info().version) +
|
||||||
|
" with name '" + op->get_friendly_name() + "'. Node If supports only opset8 version.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
MKLDNNIfNode::MKLDNNIfNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) :
|
||||||
|
MKLDNNNode(op, eng, cache), ovOp(op) {
|
||||||
|
std::string errorMessage;
|
||||||
|
if (!isSupportedOperation(op, errorMessage)) {
|
||||||
|
IE_THROW(NotImplemented) << errorMessage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MKLDNNIfNode::getSupportedDescriptors() {
|
||||||
|
auto ifOp = ov::as_type_ptr<ov::op::v8::If>(ovOp);
|
||||||
|
|
||||||
|
const std::shared_ptr<const ov::Function>& thenBody = ifOp->get_then_body();
|
||||||
|
const std::shared_ptr<const ov::Function>& elseBody = ifOp->get_else_body();
|
||||||
|
subGraphThen.CreateGraph(thenBody, ext_mng, weightCache);
|
||||||
|
subGraphElse.CreateGraph(elseBody, ext_mng, weightCache);
|
||||||
|
|
||||||
|
const auto &inMapThen = subGraphThen.GetInputNodesMap();
|
||||||
|
for (const auto ¶m : ifOp->get_then_body()->get_parameters()) {
|
||||||
|
auto inNode = inMapThen.find(param->get_friendly_name());
|
||||||
|
if (inNode != inMapThen.end()) {
|
||||||
|
auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr();
|
||||||
|
inputMemThen.push_back(inMem);
|
||||||
|
} else {
|
||||||
|
IE_THROW() << "Then body of node If with name " << getName() << " does not have input with name: "
|
||||||
|
<< param->get_friendly_name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &inMapElse = subGraphElse.GetInputNodesMap();
|
||||||
|
for (const auto ¶m : ifOp->get_else_body()->get_parameters()) {
|
||||||
|
auto inNode = inMapElse.find(param->get_friendly_name());
|
||||||
|
if (inNode != inMapElse.end()) {
|
||||||
|
auto inMem = inNode->second->getChildEdgeAt(0)->getMemoryPtr();
|
||||||
|
inputMemElse.push_back(inMem);
|
||||||
|
} else {
|
||||||
|
IE_THROW() << "Else body of node If with name " << getName() << " does not have input with name: "
|
||||||
|
<< param->get_friendly_name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &outMapThen = subGraphThen.GetOutputNodesMap();
|
||||||
|
for (const auto& out : ifOp->get_then_body()->get_results()) {
|
||||||
|
auto prev = out->get_input_node_shared_ptr(0);
|
||||||
|
std::string inputID = prev->get_friendly_name();
|
||||||
|
if (prev->get_output_size() > 1) {
|
||||||
|
inputID += "." + std::to_string(out->get_input_source_output(0).get_index());
|
||||||
|
}
|
||||||
|
auto outNode = outMapThen.find(inputID);
|
||||||
|
if (outNode != outMapThen.end()) {
|
||||||
|
auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr();
|
||||||
|
outputMemThen.push_back(outMem);
|
||||||
|
} else {
|
||||||
|
IE_THROW() << "Then body of node If with name " << getName() << " does not have output with name: "
|
||||||
|
<< inputID;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &outMapElse = subGraphElse.GetOutputNodesMap();
|
||||||
|
for (const auto& out : ifOp->get_else_body()->get_results()) {
|
||||||
|
auto prev = out->get_input_node_shared_ptr(0);
|
||||||
|
std::string inputID = prev->get_friendly_name();
|
||||||
|
if (prev->get_output_size() > 1) {
|
||||||
|
inputID += "." + std::to_string(out->get_input_source_output(0).get_index());
|
||||||
|
}
|
||||||
|
auto outNode = outMapElse.find(inputID);
|
||||||
|
if (outNode != outMapElse.end()) {
|
||||||
|
auto outMem = outNode->second->getParentEdgeAt(0)->getMemoryPtr();
|
||||||
|
outputMemElse.push_back(outMem);
|
||||||
|
} else {
|
||||||
|
IE_THROW() << "Else body of node If with name " << getName() << " does not have output with name: "
|
||||||
|
<< inputID;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port map: outputs
|
||||||
|
for (const auto& desc : ifOp->get_output_descriptions(0)) {
|
||||||
|
auto body_output_idx = desc->m_body_value_index;
|
||||||
|
thenOutputPortMap.emplace_back(PortMap {
|
||||||
|
static_cast<int>(desc->m_output_index), static_cast<int>(body_output_idx)});
|
||||||
|
}
|
||||||
|
for (const auto& desc : ifOp->get_output_descriptions(1)) {
|
||||||
|
auto body_output_idx = desc->m_body_value_index;
|
||||||
|
elseOutputPortMap.emplace_back(PortMap {
|
||||||
|
static_cast<int>(desc->m_output_index), static_cast<int>(body_output_idx)});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& desc : ifOp->get_input_descriptions(0)) {
|
||||||
|
auto body_input_index = desc->m_body_parameter_index;
|
||||||
|
thenInputPortMap.emplace_back(PortMap {
|
||||||
|
static_cast<int>(desc->m_input_index), static_cast<int>(body_input_index)});
|
||||||
|
}
|
||||||
|
for (const auto& desc : ifOp->get_input_descriptions(1)) {
|
||||||
|
auto body_input_index = desc->m_body_parameter_index;
|
||||||
|
elseInputPortMap.emplace_back(PortMap {
|
||||||
|
static_cast<int>(desc->m_input_index), static_cast<int>(body_input_index)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MKLDNNIfNode::initSupportedPrimitiveDescriptors() {
|
||||||
|
if (!supportedPrimitiveDescriptors.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
NodeConfig config;
|
||||||
|
config.inConfs.reserve(getParentEdges().size());
|
||||||
|
config.outConfs.reserve(getChildEdges().size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < inputShapes.size(); i++) {
|
||||||
|
auto dims = inputShapes[i].getDims();
|
||||||
|
|
||||||
|
PortConfig dataConf {};
|
||||||
|
auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp);
|
||||||
|
dataConf.desc = descCreator->createSharedDesc(getOriginalInputPrecisionAtPort(i), Shape(dims));
|
||||||
|
config.inConfs.emplace_back(dataConf);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < outputShapes.size(); i++) {
|
||||||
|
auto dims = outputShapes[i].getDims();
|
||||||
|
|
||||||
|
PortConfig dataConf {};
|
||||||
|
auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp);
|
||||||
|
dataConf.desc = descCreator->createSharedDesc(getOriginalOutputPrecisionAtPort(i), Shape(dims));
|
||||||
|
config.outConfs.push_back(dataConf);
|
||||||
|
}
|
||||||
|
|
||||||
|
config.dynBatchSupport = true;
|
||||||
|
|
||||||
|
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void MKLDNNIfNode::createPrimitive() {
|
||||||
|
const auto& eng = getEngine();
|
||||||
|
|
||||||
|
for (auto& map_rule : thenInputPortMap) {
|
||||||
|
auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
|
||||||
|
auto &toMem = inputMemThen[map_rule.to];
|
||||||
|
|
||||||
|
beforeThenMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& map_rule : elseInputPortMap) {
|
||||||
|
auto &fromMem = getParentEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
|
||||||
|
auto &toMem = inputMemElse[map_rule.to];
|
||||||
|
|
||||||
|
beforeElseMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& map_rule : thenOutputPortMap) {
|
||||||
|
auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
|
||||||
|
auto &fromMem = outputMemThen[map_rule.to];
|
||||||
|
|
||||||
|
afterThenMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& map_rule : elseOutputPortMap) {
|
||||||
|
auto &toMem = getChildEdgesAtPort(map_rule.from)[0]->getMemoryPtr();
|
||||||
|
auto &fromMem = outputMemElse[map_rule.to];
|
||||||
|
|
||||||
|
afterElseMappers.emplace_back(std::make_shared<PortMapHelper>(fromMem, toMem, eng));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MKLDNNIfNode::execute(mkldnn::stream strm) {
|
||||||
|
const bool condition = *(reinterpret_cast<const bool*>(getParentEdgeAt(0)->getMemoryPtr()->GetPtr()));
|
||||||
|
|
||||||
|
if (condition) {
|
||||||
|
for (auto &mapper : beforeThenMappers)
|
||||||
|
mapper->execute(strm);
|
||||||
|
subGraphThen.ResetInferCount();
|
||||||
|
subGraphThen.Infer();
|
||||||
|
for (auto &mapper : afterThenMappers)
|
||||||
|
mapper->execute(strm);
|
||||||
|
} else {
|
||||||
|
for (auto &mapper : beforeElseMappers)
|
||||||
|
mapper->execute(strm);
|
||||||
|
subGraphElse.ResetInferCount();
|
||||||
|
subGraphElse.Infer();
|
||||||
|
for (auto &mapper : afterElseMappers)
|
||||||
|
mapper->execute(strm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MKLDNNIfNode::created() const {
|
||||||
|
return getType() == If;
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_MKLDNN_PRIM_FOR(MKLDNNIfNode, If);
|
66
inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.h
Normal file
66
inference-engine/src/mkldnn_plugin/nodes/mkldnn_if_node.h
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <mkldnn_node.h>
|
||||||
|
#include <mkldnn_graph.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace MKLDNNPlugin {
|
||||||
|
|
||||||
|
class MKLDNNIfNode : public MKLDNNNode {
|
||||||
|
public:
|
||||||
|
MKLDNNIfNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);
|
||||||
|
|
||||||
|
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
|
||||||
|
void initSupportedPrimitiveDescriptors() override;
|
||||||
|
void getSupportedDescriptors() override;
|
||||||
|
void createPrimitive() override;
|
||||||
|
bool created() const override;
|
||||||
|
void execute(mkldnn::stream strm) override;
|
||||||
|
|
||||||
|
void inline setExtManager(const MKLDNNExtensionManager::Ptr& extMgr) { ext_mng = extMgr; }
|
||||||
|
|
||||||
|
struct PortMap {
|
||||||
|
int from; /**< Index of external/internal out data */
|
||||||
|
int to; /**< Index of external/internal in data */
|
||||||
|
};
|
||||||
|
|
||||||
|
class PortMapHelper {
|
||||||
|
public:
|
||||||
|
PortMapHelper(const MKLDNNMemoryPtr& from, const MKLDNNMemoryPtr& to, const mkldnn::engine& eng);
|
||||||
|
virtual ~PortMapHelper() = default;
|
||||||
|
virtual void execute(mkldnn::stream& strm);
|
||||||
|
protected:
|
||||||
|
mkldnn::reorder reorder;
|
||||||
|
mkldnn::memory mem_holder_src;
|
||||||
|
mkldnn::memory mem_holder_dst;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
MKLDNNExtensionManager::Ptr ext_mng;
|
||||||
|
MKLDNNGraph subGraphThen;
|
||||||
|
MKLDNNGraph subGraphElse;
|
||||||
|
std::deque<MKLDNNMemoryPtr> inputMemThen, inputMemElse, outputMemThen, outputMemElse;
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<PortMapHelper>>
|
||||||
|
beforeThenMappers,
|
||||||
|
beforeElseMappers,
|
||||||
|
afterThenMappers,
|
||||||
|
afterElseMappers;
|
||||||
|
|
||||||
|
std::vector<PortMap>
|
||||||
|
thenInputPortMap,
|
||||||
|
thenOutputPortMap,
|
||||||
|
elseInputPortMap,
|
||||||
|
elseOutputPortMap;
|
||||||
|
|
||||||
|
const std::shared_ptr<ov::Node> ovOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace MKLDNNPlugin
|
@ -0,0 +1,57 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "shared_test_classes/subgraph/simple_if.hpp"
|
||||||
|
|
||||||
|
using namespace SubgraphTestsDefinitions;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::vector<std::vector<std::vector<size_t>>> inputShapes = {
|
||||||
|
{{5, 7}, {5, 7}},
|
||||||
|
{{30, 20, 10}, {30, 20, 10}}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
|
||||||
|
InferenceEngine::Precision::I8,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<bool> conditions = {true, false};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(inputShapes),
|
||||||
|
::testing::ValuesIn(netPrecisions),
|
||||||
|
::testing::ValuesIn(conditions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||||
|
SimpleIfTest::getTestCaseName);
|
||||||
|
|
||||||
|
TEST_P(SimpleIfTest, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIf2OutTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(inputShapes),
|
||||||
|
::testing::ValuesIn(netPrecisions),
|
||||||
|
::testing::ValuesIn(conditions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||||
|
SimpleIf2OutTest::getTestCaseName);
|
||||||
|
|
||||||
|
TEST_P(SimpleIf2OutTest, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_If, SimpleIfNotConstConditionTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(inputShapes),
|
||||||
|
::testing::ValuesIn(netPrecisions),
|
||||||
|
::testing::ValuesIn(conditions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||||
|
SimpleIfNotConstConditionTest::getTestCaseName);
|
||||||
|
|
||||||
|
TEST_P(SimpleIfNotConstConditionTest, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
@ -0,0 +1,45 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||||
|
|
||||||
|
namespace SubgraphTestsDefinitions {
|
||||||
|
|
||||||
|
using SimpleIfParamsTuple = typename std::tuple<
|
||||||
|
std::vector<std::vector<size_t>>, // Input shapes
|
||||||
|
InferenceEngine::Precision, // Network precision
|
||||||
|
bool, // If condition
|
||||||
|
std::string // Device name
|
||||||
|
>;
|
||||||
|
|
||||||
|
class SimpleIfTest:
|
||||||
|
public testing::WithParamInterface<SimpleIfParamsTuple>,
|
||||||
|
public LayerTestsUtils::LayerTestsCommon {
|
||||||
|
public:
|
||||||
|
static std::string getTestCaseName(const testing::TestParamInfo<SimpleIfParamsTuple> &obj);
|
||||||
|
protected:
|
||||||
|
void SetUp() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SimpleIf2OutTest : public SimpleIfTest {
|
||||||
|
protected:
|
||||||
|
void SetUp() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SimpleIfNotConstConditionTest : public SimpleIfTest {
|
||||||
|
public:
|
||||||
|
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override;
|
||||||
|
|
||||||
|
bool condition;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace SubgraphTestsDefinitions
|
@ -0,0 +1,142 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "shared_test_classes/subgraph/simple_if.hpp"
|
||||||
|
#include "ngraph_functions/builders.hpp"
|
||||||
|
|
||||||
|
namespace SubgraphTestsDefinitions {
|
||||||
|
std::string SimpleIfTest::getTestCaseName(const testing::TestParamInfo<SimpleIfParamsTuple> &obj) {
|
||||||
|
std::vector<std::vector<size_t>> inputShapes;
|
||||||
|
InferenceEngine::Precision netPrecision;
|
||||||
|
bool condition;
|
||||||
|
std::string targetName;
|
||||||
|
std::tie(inputShapes, netPrecision, condition, targetName) = obj.param;
|
||||||
|
std::ostringstream results;
|
||||||
|
|
||||||
|
results << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
|
||||||
|
results << "netPRC=" << netPrecision.name() << "_";
|
||||||
|
results << "Cond=" << condition << "_";
|
||||||
|
results << "targetDevice=" << targetName << "_";
|
||||||
|
return results.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void SimpleIfTest::SetUp() {
|
||||||
|
std::vector<std::vector<size_t>> inputShapes;
|
||||||
|
auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
|
||||||
|
bool condition;
|
||||||
|
std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||||
|
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
|
||||||
|
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||||
|
ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));
|
||||||
|
|
||||||
|
auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
|
||||||
|
auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
|
||||||
|
auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
|
||||||
|
|
||||||
|
auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
|
||||||
|
auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
|
||||||
|
auto res2 = std::make_shared<ov::op::v0::Result>(p3);
|
||||||
|
|
||||||
|
auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1}, ov::ParameterVector{p1, p2});
|
||||||
|
auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res2}, ov::ParameterVector{p3});
|
||||||
|
|
||||||
|
auto condOp = ngraph::builder::makeConstant<bool>(ov::element::Type_t::boolean, {1}, {condition});
|
||||||
|
auto ifOp = std::make_shared<ov::op::v8::If>(condOp);
|
||||||
|
ifOp->set_then_body(thenBody);
|
||||||
|
ifOp->set_else_body(elseBody);
|
||||||
|
ifOp->set_input(paramOuts[0], p1, p3);
|
||||||
|
ifOp->set_input(paramOuts[1], p2, nullptr);
|
||||||
|
auto res = ifOp->set_output(res1, res2);
|
||||||
|
|
||||||
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(res)};
|
||||||
|
function = std::make_shared<ov::Function>(results, params, "simpleIf");
|
||||||
|
}
|
||||||
|
|
||||||
|
void SimpleIf2OutTest::SetUp() {
|
||||||
|
std::vector<std::vector<size_t>> inputShapes;
|
||||||
|
auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
|
||||||
|
bool condition;
|
||||||
|
std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||||
|
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
|
||||||
|
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||||
|
ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));
|
||||||
|
|
||||||
|
auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
|
||||||
|
auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
|
||||||
|
auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
|
||||||
|
auto p4 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
|
||||||
|
|
||||||
|
auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
|
||||||
|
auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
|
||||||
|
auto res2 = std::make_shared<ov::op::v0::Result>(thenOp);
|
||||||
|
auto res3 = std::make_shared<ov::op::v0::Result>(p3);
|
||||||
|
auto res4 = std::make_shared<ov::op::v0::Result>(p4);
|
||||||
|
|
||||||
|
auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2});
|
||||||
|
auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4});
|
||||||
|
|
||||||
|
auto condOp = ngraph::builder::makeConstant<bool>(ov::element::Type_t::boolean, {1}, {condition});
|
||||||
|
auto ifOp = std::make_shared<ov::op::v8::If>(condOp);
|
||||||
|
ifOp->set_then_body(thenBody);
|
||||||
|
ifOp->set_else_body(elseBody);
|
||||||
|
ifOp->set_input(paramOuts[0], p1, p3);
|
||||||
|
ifOp->set_input(paramOuts[1], p2, p4);
|
||||||
|
auto ifRes1 = ifOp->set_output(res1, res3);
|
||||||
|
auto ifRes2 = ifOp->set_output(res2, res4);
|
||||||
|
|
||||||
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(ifRes1), std::make_shared<ov::op::v0::Result>(ifRes2)};
|
||||||
|
function = std::make_shared<ov::Function>(results, params, "simpleIf2Out");
|
||||||
|
}
|
||||||
|
|
||||||
|
void SimpleIfNotConstConditionTest::SetUp() {
|
||||||
|
std::vector<std::vector<size_t>> inputShapes;
|
||||||
|
auto netPrecision = InferenceEngine::Precision::UNSPECIFIED;
|
||||||
|
std::tie(inputShapes, netPrecision, condition, targetDevice) = this->GetParam();
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||||
|
ov::ParameterVector params {
|
||||||
|
ngraph::builder::makeParams(ngPrc, {inputShapes[0]})[0],
|
||||||
|
ngraph::builder::makeParams(ngPrc, {inputShapes[1]})[0],
|
||||||
|
ngraph::builder::makeParams(ov::element::boolean, { {"condition", {1}} })[0]
|
||||||
|
};
|
||||||
|
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||||
|
ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(params));
|
||||||
|
|
||||||
|
auto p1 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
|
||||||
|
auto p2 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
|
||||||
|
auto p3 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[0]));
|
||||||
|
auto p4 = std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::PartialShape(inputShapes[1]));
|
||||||
|
|
||||||
|
auto thenOp = std::make_shared<ov::op::v1::Add>(p1, p2);
|
||||||
|
auto res1 = std::make_shared<ov::op::v0::Result>(thenOp);
|
||||||
|
auto res2 = std::make_shared<ov::op::v0::Result>(thenOp);
|
||||||
|
auto res3 = std::make_shared<ov::op::v0::Result>(p3);
|
||||||
|
auto res4 = std::make_shared<ov::op::v0::Result>(p4);
|
||||||
|
|
||||||
|
auto thenBody = std::make_shared<ov::Function>(ov::OutputVector{res1, res2}, ov::ParameterVector{p1, p2});
|
||||||
|
auto elseBody = std::make_shared<ov::Function>(ov::OutputVector{res3, res4}, ov::ParameterVector{p3, p4});
|
||||||
|
|
||||||
|
auto ifOp = std::make_shared<ov::op::v8::If>(paramOuts[2]);
|
||||||
|
ifOp->set_then_body(thenBody);
|
||||||
|
ifOp->set_else_body(elseBody);
|
||||||
|
ifOp->set_input(paramOuts[0], p1, p3);
|
||||||
|
ifOp->set_input(paramOuts[1], p2, p4);
|
||||||
|
auto ifRes1 = ifOp->set_output(res1, res3);
|
||||||
|
auto ifRes2 = ifOp->set_output(res2, res4);
|
||||||
|
|
||||||
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(ifRes1), std::make_shared<ov::op::v0::Result>(ifRes2)};
|
||||||
|
function = std::make_shared<ov::Function>(results, params, "SimpleIfNotConstConditionTest");
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceEngine::Blob::Ptr SimpleIfNotConstConditionTest::GenerateInput(const InferenceEngine::InputInfo& info) const {
|
||||||
|
if (info.name() == "condition") {
|
||||||
|
bool conditionArr[1] = { condition };
|
||||||
|
return FuncTestUtils::createAndFillBlobWithFloatArray(info.getTensorDesc(), conditionArr, 1);
|
||||||
|
} else {
|
||||||
|
return FuncTestUtils::createAndFillBlob(info.getTensorDesc());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace SubgraphTestsDefinitions
|
@ -83,7 +83,7 @@ template<typename fromType, typename toType>
|
|||||||
std::vector<toType> castVector(const std::vector<fromType> &vec) {
|
std::vector<toType> castVector(const std::vector<fromType> &vec) {
|
||||||
std::vector<toType> resVec;
|
std::vector<toType> resVec;
|
||||||
resVec.reserve(vec.size());
|
resVec.reserve(vec.size());
|
||||||
for (auto &el : vec) {
|
for (const auto &el : vec) {
|
||||||
resVec.push_back(static_cast<toType>(el));
|
resVec.push_back(static_cast<toType>(el));
|
||||||
}
|
}
|
||||||
return resVec;
|
return resVec;
|
||||||
|
@ -273,6 +273,13 @@ public:
|
|||||||
/// \param bodies_results vector of bodies results for one output.
|
/// \param bodies_results vector of bodies results for one output.
|
||||||
/// \return value Output node for bodies_results.
|
/// \return value Output node for bodies_results.
|
||||||
virtual Output<Node> set_body_outputs(const ResultVector& bodies_results);
|
virtual Output<Node> set_body_outputs(const ResultVector& bodies_results);
|
||||||
|
///
|
||||||
|
/// \brief Get number of internal sub-graphs
|
||||||
|
///
|
||||||
|
/// \return Number of sub-graphs.
|
||||||
|
virtual size_t get_num_internal_subgraphs() const {
|
||||||
|
return m_bodies.size();
|
||||||
|
}
|
||||||
|
|
||||||
MultiSubGraphOp(const MultiSubGraphOp&) = delete;
|
MultiSubGraphOp(const MultiSubGraphOp&) = delete;
|
||||||
MultiSubGraphOp(MultiSubGraphOp&&) = default;
|
MultiSubGraphOp(MultiSubGraphOp&&) = default;
|
||||||
|
@ -45,9 +45,10 @@ bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr<ov::Function> f)
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
|
||||||
if (const auto& sub_graph = sub_graph_node->get_function()) {
|
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
|
||||||
rewritten |= run_on_function(sub_graph);
|
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
|
||||||
|
rewritten |= run_on_function(sub_graph_node->get_function(sub_graph_ind));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -161,9 +161,10 @@ bool convert_precision(pass::PassBase& pass,
|
|||||||
for (auto& node : ops) {
|
for (auto& node : ops) {
|
||||||
pass.transformation_callback(node);
|
pass.transformation_callback(node);
|
||||||
// Recursively apply transformation for sub-graph based operations
|
// Recursively apply transformation for sub-graph based operations
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
|
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
|
||||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
|
||||||
is_changed |= convert_function_precision(sub_graph, true);
|
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
|
||||||
|
is_changed |= convert_function_precision(sub_graph_node->get_function(sub_graph_ind), true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is_changed |= convert_node_input_precision(node);
|
is_changed |= convert_node_input_precision(node);
|
||||||
@ -226,9 +227,10 @@ precisions_set_t find_all_used_precisions(const std::shared_ptr<ngraph::Function
|
|||||||
for (const auto& output : node->outputs()) {
|
for (const auto& output : node->outputs()) {
|
||||||
used_precisions.emplace(output.get_element_type());
|
used_precisions.emplace(output.get_element_type());
|
||||||
}
|
}
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
|
||||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
|
||||||
auto sub_graph_precisions = find_all_used_precisions(sub_graph);
|
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
|
||||||
|
auto sub_graph_precisions = find_all_used_precisions(sub_graph_node->get_function(sub_graph_ind));
|
||||||
used_precisions.insert(sub_graph_precisions.begin(), sub_graph_precisions.end());
|
used_precisions.insert(sub_graph_precisions.begin(), sub_graph_precisions.end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -172,8 +172,10 @@ bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Recursive apply Matchers for sub-graph based nodes
|
// Recursive apply Matchers for sub-graph based nodes
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::MultiSubGraphOp>(node)) {
|
||||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
size_t sub_graphs_num = sub_graph_node->get_num_internal_subgraphs();
|
||||||
|
for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
|
||||||
|
auto sub_graph = sub_graph_node->get_function(sub_graph_ind);
|
||||||
run_on_function(sub_graph);
|
run_on_function(sub_graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user