[CPU] Avoid inserting additional transpose + reorder after RNN node. (#5921)

This commit is contained in:
Nikolay Shchegolev
2021-08-24 12:28:23 +03:00
committed by GitHub
parent c69425a96a
commit de46168e98
11 changed files with 102 additions and 43 deletions

View File

@@ -17,6 +17,7 @@
#include <nodes/mkldnn_transpose_node.h>
#include "nodes/mkldnn_interpolate_node.h"
#include "nodes/mkldnn_input_node.h"
#include "nodes/mkldnn_rnn.h"
#include "nodes/common/cpu_convert.h"
#include "mkldnn/ie_mkldnn.h"
@@ -132,6 +133,10 @@ void MKLDNNGraphOptimizer::ApplyCommonGraphOptimizations(MKLDNNGraph &graph) {
FuseEltwiseAndSimple(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "reshapeRnnSeq");
reshapeRnnSeq(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveDroppedEdges");
graph.RemoveDroppedEdges();
}
@@ -973,7 +978,7 @@ static bool is_data_dependency(const std::shared_ptr<MKLDNNNode> &parent,
*/
void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNGraph &graph) {
std::vector<MKLDNNNodePtr> &graphNodes = graph.GetNodes();
auto &graphNodes = graph.GetNodes();
auto isFusingSupported = [&](MKLDNNNodePtr conv, MKLDNNNodePtr child) {
return child->getType() == Eltwise &&
@@ -1444,7 +1449,7 @@ void MKLDNNGraphOptimizer::DropDoubleReorders(MKLDNNGraph &graph) {
}
void MKLDNNGraphOptimizer::FuseBroadcastAndEltwise(MKLDNNGraph &graph) {
std::vector<MKLDNNNodePtr>& graphNodes = graph.GetNodes();
auto& graphNodes = graph.GetNodes();
for (auto &graphNode : graphNodes) {
if (graphNode->getType() != Generic
@@ -1816,3 +1821,43 @@ void MKLDNNGraphOptimizer::MergeTransposeAndReorder(MKLDNNGraph &graph) {
}
}
}
// Squeezes the unit batch dimension out of an RNNSeq node's first output and inserts an
// explicit Reshape on every consumer edge to restore the 4D shape the children expect.
// This avoids inserting an additional Transpose + Reorder pair after the RNN node.
// Applies only to non-native-order RNNSeq nodes with a rank-4 output whose dim[1] == 1.
void MKLDNNGraphOptimizer::reshapeRnnSeq(MKLDNNGraph &graph) {
    auto& graphNodes = graph.GetNodes();

    auto isSuitableParentNode = [](MKLDNNNodePtr node) {
        if (node->type != RNNSeq)
            return false;
        auto rnnNode = std::dynamic_pointer_cast<MKLDNNRNN>(node);
        return rnnNode && !rnnNode->hasNativeOrder() &&
               node->outputShapes[0].getRank() == 4 &&
               node->outputShapes[0].getDims()[1] == 1;
    };

    // size_t index avoids the signed/unsigned comparison of the original `int i`.
    for (size_t i = 0; i < graphNodes.size(); i++) {
        auto& parentNode = graphNodes[i];
        if (!isSuitableParentNode(parentNode)) {
            continue;
        }

        auto childrenEdges = parentNode->getChildEdgesAtPort(0);

        // Drop the unit dimension at axis 1 from the RNN output shape (4D -> 3D).
        auto newRnnOutDims = parentNode->outputShapes[0].getDims();
        newRnnOutDims.erase(newRnnOutDims.begin() + 1);
        parentNode->outputShapes[0] = Shape{newRnnOutDims};

        // `j` (not a shadowed `i`) numbers the per-edge Reshape nodes; the value used
        // in the node name is unchanged since the original inner loop also started at 0.
        for (size_t j = 0; j < childrenEdges.size(); j++) {
            auto edge = childrenEdges[j];
            auto childNode = edge->getChild();

            // Reshape back to the child's expected input shape; precision is taken
            // from the RNN's original output precision at port 0.
            const MKLDNNNodePtr newReshape = std::make_shared<MKLDNNReshapeNode>(
                    parentNode->getName() + "_abc_a1bc_" + std::to_string(j),
                    parentNode->outputShapes[0],
                    childNode->inputShapes[edge->getOutputNum()],
                    parentNode->getOriginalOutputPrecisionAtPort(0),
                    graph.getEngine(), graph.weightsCache);

            graph.InsertNode(parentNode, childNode, newReshape, edge->getInputNum(), edge->getOutputNum(), false);

            // The direct parent->child edge is replaced by parent->Reshape->child.
            edge->drop();
            graph.RemoveEdge(edge);
        }
    }
}

View File

@@ -39,6 +39,7 @@ private:
void FusePerformedAsScaleShiftAndFakeQuantize(MKLDNNGraph &graph);
void FuseClampAndFakeQuantize(MKLDNNGraph &graph);
void MergeTransposeAndReorder(MKLDNNGraph &graph);
void reshapeRnnSeq(MKLDNNGraph &graph);
};
} // namespace MKLDNNPlugin

View File

@@ -56,34 +56,15 @@ namespace {
auto reshape1 = std::make_shared<ngraph::op::v1::Reshape>(in_0, newInShape, false);
ngraph::replace_node(sequenceOp->get_input_node_shared_ptr(0), {reshape1->output(0)});
const auto &gruTargetInputs = sequenceOp->output(0).get_target_inputs();
if (gruTargetInputs.empty())
const auto &seqTargetInputs = sequenceOp->output(0).get_target_inputs();
if (seqTargetInputs.empty())
return false;
auto transposeAfter = gruTargetInputs.begin()->get_node()->shared_from_this();
auto transposeAfter = seqTargetInputs.begin()->get_node()->shared_from_this();
auto newOutShape = ngraph::op::v0::Constant::create(ngraph::element::i32, ngraph::Shape{4}, transposeAfter->get_output_shape(0));
auto reshape2 = std::make_shared<ngraph::op::v1::Reshape>(sequenceOp->output(0), newOutShape, false);
reshape2->set_friendly_name(transposeAfter->get_friendly_name());
ngraph::replace_node(transposeAfter, {reshape2->output(0)});
} else {
auto originShape = sequenceOp->get_output_shape(0);
const auto targetInputs = sequenceOp->get_output_target_inputs(0);
if (targetInputs.empty()) {
return false;
}
auto seqOut = targetInputs.begin()->get_node()->shared_from_this();
auto tncShape = ngraph::op::v0::Constant::create(ngraph::element::i32, ngraph::Shape{3}, {originShape[2], originShape[0], originShape[3]});
auto reshape1 = std::make_shared<ngraph::op::v1::Reshape>(sequenceOp->output(0), tncShape, false);
auto order = ngraph::op::v0::Constant::create(ngraph::element::i32, ngraph::Shape{3}, {1, 0, 2});
auto transpose = std::make_shared<ngraph::op::v1::Transpose>(reshape1->output(0), order);
auto ndtcShape = ngraph::op::v0::Constant::create(ngraph::element::i32, ngraph::Shape{4}, originShape);
auto reshape2 = std::make_shared<ngraph::op::v1::Reshape>(transpose->output(0), ndtcShape, false);
reshape2->set_friendly_name(sequenceOp->get_friendly_name()+".0");
ngraph::insert_new_node_between(sequenceOp, seqOut, reshape2);
}
sequenceOp->get_rt_info()["seqAxis"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(seqAxis);

View File

@@ -14,6 +14,15 @@ using namespace InferenceEngine;
// Ctor from an ngraph op: shapes and precisions are extracted by the MKLDNNNode base ctor.
MKLDNNReshapeNode::MKLDNNReshapeNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) :
MKLDNNNode(op, eng, cache) {}
// Ctor for Reshape nodes created internally by the graph optimizer (no ngraph op
// available): input/output shapes and a single precision — used for both the
// input and the output port — are supplied explicitly by the caller.
MKLDNNReshapeNode::MKLDNNReshapeNode(const std::string& name, const Shape& inDims, const Shape& outDims, Precision precision,
                                     const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &wCache)
        : MKLDNNNode("Reshape", name, eng, wCache) {
    inputShapes.push_back(inDims);
    outputShapes.push_back(outDims);
    addOriginalInputPrecision(precision);
    addOriginalOutputPrecision(precision);
}
void MKLDNNReshapeNode::getSupportedDescriptors() {
if (getParentEdges().size() != 1 && getParentEdges().size() != 2)
IE_THROW() << "Incorrect number of input edges for layer " << getName();

View File

@@ -15,6 +15,12 @@ namespace MKLDNNPlugin {
class MKLDNNReshapeNode : public MKLDNNNode {
public:
MKLDNNReshapeNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);
MKLDNNReshapeNode(const std::string& name,
const Shape& inDims,
const Shape& outDims,
InferenceEngine::Precision precision,
const mkldnn::engine& eng,
MKLDNNWeightsSharing::Ptr &wCache);
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;

View File

@@ -412,6 +412,9 @@ void MKLDNNRNN::fillSeqDesc() {
if (nativeOrder)
in_candidate.emplace_back(inputShapes[RNNInOutKind::Layer].getStaticDims(), dataType, memory::format_tag::tnc);
else if (N == 1)
// WA to avoid reorder before sequence for some models
in_candidate.emplace_back(std::vector<size_t>{N, T, DC}, dataType, memory::format_tag::tnc);
else
in_candidate.emplace_back(std::vector<size_t>{N, T, DC}, dataType, memory::format_tag::ntc);
@@ -428,9 +431,11 @@ void MKLDNNRNN::fillSeqDesc() {
if (nativeOrder) {
out_candidate.emplace_back(out_data_d[RNNInOutKind::Layer]);
} else if (N == 1) {
// WA to avoid reorder after sequence for some models
out_candidate.emplace_back(std::vector<size_t>{N, T, SC}, dataType, memory::format_tag::tnc);
} else {
// TODO reorder ntc -> ndtc does not work, thus use tnc(plain) + transformation reshape-transpose-reshape for now.
out_candidate.emplace_back(std::vector<size_t>{T, N, SC}, dataType, memory::format_tag::tnc);
out_candidate.emplace_back(std::vector<size_t>{N, T, SC}, dataType, memory::format_tag::ntc);
}
out_candidate.emplace_back(std::vector<size_t>{N, D, SC}, dataType, memory::format_tag::ntc);

View File

@@ -24,6 +24,10 @@ public:
void execute(mkldnn::stream strm) override;
// Returns the node's nativeOrder flag. When true the node declares tnc-format
// descriptors directly (see fillSeqDesc); the graph optimizer (reshapeRnnSeq)
// skips its reorder-avoidance workaround for native-order nodes.
inline bool hasNativeOrder() const {
return nativeOrder;
}
private:
void initCell(const std::shared_ptr<ngraph::Node>& op);
void initSeq(const std::shared_ptr<ngraph::Node>& op);

View File

@@ -115,13 +115,10 @@ protected:
// returned output format always tnc
if (ngraph::shape_size(gru_sequence->get_output_shape(0)) == 1) {
outFmts[0] = tnc;
} else if (ngraph::shape_size(gru_sequence->get_output_shape(1)) == 1) {
} else if (ngraph::shape_size(gru_sequence->get_output_shape(1)) == 1 ||
gru_sequence->get_output_shape(0)[0] == 1) {
outFmts[1] = tnc;
}
// if output format equals for all outputs, runtime info return only one formats
if (outFmts[0] == outFmts[1]) {
outFmts.erase(outFmts.begin());
}
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
@@ -170,8 +167,8 @@ namespace {
std::vector<std::map<std::string, std::string>> additionalConfig
= {{{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::NO}}, {{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}}};
CPUSpecificParams cpuParams{{ntc, ntc}, {tnc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParamsBatchSizeOne{{ntc, ntc}, {tnc, ntc}, {"ref_any"}, "ref_any"};;
CPUSpecificParams cpuParams{{ntc, ntc}, {ntc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParamsBatchSizeOne{{tnc, ntc}, {tnc, ntc}, {"ref_any"}, "ref_any"};;
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lengths = 2

View File

@@ -119,15 +119,12 @@ protected:
// returned output format always tnc
if (outFmts.size() >= 3) {
for (size_t i = 1; i < 3; i++) {
if (ngraph::shape_size(lstm_sequence->get_output_shape(i)) == 1) {
if (ngraph::shape_size(lstm_sequence->get_output_shape(i)) == 1 ||
lstm_sequence->get_output_shape(0) == ngraph::Shape{1, 1, 2, 10}) {
outFmts[i] = tnc;
}
}
}
// if output format equals for all outputs, runtime info return only one formats
if (std::adjacent_find(outFmts.begin(), outFmts.end(), std::not_equal_to<cpu_memory_format_t>()) == outFmts.end()) {
outFmts.resize(1);
}
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
@@ -179,8 +176,8 @@ std::vector<std::map<std::string, std::string>> additionalConfig
= {{{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::NO}},
{{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}}};
CPUSpecificParams cpuParams{{ntc, ntc, ntc}, {tnc, ntc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParamsBatchSizeOne{{ntc, ntc, ntc}, {tnc, ntc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParams{{ntc, ntc, ntc}, {ntc, ntc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParamsBatchSizeOne{{tnc, ntc, ntc}, {tnc, ntc, ntc}, {"ref_any"}, "ref_any"};
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::PURE_SEQ};
std::vector<size_t> seq_lengths_zero_clip{2};

View File

@@ -148,8 +148,8 @@ namespace {
std::vector<std::map<std::string, std::string>> additionalConfig
= {{{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::NO}}, {{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}}};
CPUSpecificParams cpuParams{{ntc, ntc}, {tnc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParamsBatchSizeOne{{ntc, ntc}, {tnc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParams{{ntc, ntc}, {ntc, ntc}, {"ref_any"}, "ref_any"};
CPUSpecificParams cpuParamsBatchSizeOne{{tnc, ntc}, {tnc, tnc}, {"ref_any"}, "ref_any"};
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lengths = 2

View File

@@ -176,7 +176,21 @@ void CPUTestsBase::CheckPluginRelatedResults(InferenceEngine::ExecutableNetwork
auto actualOutputMemoryFormats = getActualOutputMemoryFormats(getExecValueOutputsLayout(node));
for (size_t i = 0; i < outFmts.size(); i++) {
bool isAllEqual = true;
for (size_t i = 1; i < outFmts.size(); i++) {
if (outFmts[i - 1] != outFmts[i]) {
isAllEqual = false;
break;
}
}
size_t fmtsNum = outFmts.size();
if (isAllEqual) {
fmtsNum = fmtsNum == 0 ? 0 : 1;
} else {
ASSERT_EQ(fmtsNum, actualOutputMemoryFormats.size());
}
for (size_t i = 0; i < fmtsNum; i++) {
const auto actualOutputMemoryFormat = getExecValue(ExecGraphInfoSerialization::OUTPUT_LAYOUTS);
const auto shape = node->get_output_shape(i);