[CPU] fixed MergePermuteAndReorder optimization (#3318)
This commit is contained in:
parent
843d5de611
commit
627c5e6d0e
@ -1083,7 +1083,7 @@ void MKLDNNGraph::RemoveDroppedEdges() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const TensorDesc& inDesc, const TensorDesc& outDesc,
|
MKLDNNNodePtr MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const TensorDesc& inDesc, const TensorDesc& outDesc,
|
||||||
bool isOptimized, InferenceEngine::Blob::Ptr scales) {
|
bool isOptimized, InferenceEngine::Blob::Ptr scales) {
|
||||||
CNNLayerPtr layer(new CNNLayer({layerName,
|
CNNLayerPtr layer(new CNNLayer({layerName,
|
||||||
"Reorder",
|
"Reorder",
|
||||||
@ -1133,6 +1133,7 @@ void MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const
|
|||||||
}
|
}
|
||||||
|
|
||||||
graphNodes.push_back(newReorder);
|
graphNodes.push_back(newReorder);
|
||||||
|
return newReorder;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MKLDNNGraph::dumpToDotFile(std::string file) const {
|
void MKLDNNGraph::dumpToDotFile(std::string file) const {
|
||||||
|
@ -109,10 +109,10 @@ public:
|
|||||||
* optimization flag; if isOptimized is true then Reorder node does nothing
|
* optimization flag; if isOptimized is true then Reorder node does nothing
|
||||||
* @param scales
|
* @param scales
|
||||||
* pointer to the blob containing scales
|
* pointer to the blob containing scales
|
||||||
* @return none.
|
* @return pointer to the new Reorder node.
|
||||||
*/
|
*/
|
||||||
void InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const InferenceEngine::TensorDesc& inDesc, const InferenceEngine::TensorDesc& outDesc,
|
MKLDNNNodePtr InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const InferenceEngine::TensorDesc& inDesc,
|
||||||
bool isOptimized = false, InferenceEngine::Blob::Ptr scales = nullptr);
|
const InferenceEngine::TensorDesc& outDesc, bool isOptimized = false, InferenceEngine::Blob::Ptr scales = nullptr);
|
||||||
|
|
||||||
InferenceEngine::CNNNetwork dump() const;
|
InferenceEngine::CNNNetwork dump() const;
|
||||||
|
|
||||||
|
@ -2312,8 +2312,8 @@ void MKLDNNGraphOptimizer::MergePermuteAndReorder(MKLDNNGraph &graph) {
|
|||||||
graph.DropNode(parentNode);
|
graph.DropNode(parentNode);
|
||||||
graph.DropNode(childNode);
|
graph.DropNode(childNode);
|
||||||
|
|
||||||
auto inDesc = parentParentNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
|
auto inDesc = parentNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
|
||||||
auto outDesc = childChildNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
|
auto outDesc = childNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
|
||||||
|
|
||||||
auto inPrec = inDesc.getPrecision();
|
auto inPrec = inDesc.getPrecision();
|
||||||
auto outPrec = outDesc.getPrecision();
|
auto outPrec = outDesc.getPrecision();
|
||||||
@ -2333,13 +2333,12 @@ void MKLDNNGraphOptimizer::MergePermuteAndReorder(MKLDNNGraph &graph) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.InsertReorder(edge, reorderlayerName, reorderInDesc, reorderOutDesc, true);
|
auto reorderNode = graph.InsertReorder(edge, reorderlayerName, reorderInDesc, reorderOutDesc, true);
|
||||||
|
|
||||||
// case 2
|
// case 2
|
||||||
if (inPrec != outPrec) {
|
if (inPrec != outPrec) {
|
||||||
auto reorderNode = parentParentNode->getChildEdgeAt(0)->getChild();
|
auto reorderInDesc2 = TensorDesc(reorderOutDesc);
|
||||||
auto reorderInDesc2 = TensorDesc(reorderNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc);
|
auto reorderOutDesc2 = TensorDesc(outDesc);
|
||||||
auto reorderOutDesc2 = TensorDesc(childChildNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc);
|
|
||||||
|
|
||||||
std::string reorderLayerName2 = reorderNode->getName() + "_" +
|
std::string reorderLayerName2 = reorderNode->getName() + "_" +
|
||||||
MKLDNNExtensionUtils::getReorderArgs(reorderInDesc2, reorderOutDesc2) + "_" + childChildNode->getName();
|
MKLDNNExtensionUtils::getReorderArgs(reorderInDesc2, reorderOutDesc2) + "_" + childChildNode->getName();
|
||||||
|
@ -18,8 +18,8 @@ using namespace CPUTestUtils;
|
|||||||
namespace LayerTestsDefinitions {
|
namespace LayerTestsDefinitions {
|
||||||
|
|
||||||
using FusePermuteAndReorderParams = std::tuple<
|
using FusePermuteAndReorderParams = std::tuple<
|
||||||
InferenceEngine::SizeVector, // Input shape
|
InferenceEngine::SizeVector, // Input shape
|
||||||
InferenceEngine::Precision // Input precision
|
InferenceEngine::Precision // Input precision
|
||||||
>;
|
>;
|
||||||
|
|
||||||
class FusePermuteAndReorderTest : public testing::WithParamInterface<FusePermuteAndReorderParams>, public CPUTestsBase,
|
class FusePermuteAndReorderTest : public testing::WithParamInterface<FusePermuteAndReorderParams>, public CPUTestsBase,
|
||||||
@ -29,7 +29,21 @@ public:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
std::string pluginTypeNode;
|
virtual void CreateGraph();
|
||||||
|
void CheckPermuteCount(size_t expectedPermuteCount);
|
||||||
|
|
||||||
|
InferenceEngine::SizeVector inputShape;
|
||||||
|
InferenceEngine::Precision inPrec;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FusePermuteAndReorderTest1 : public FusePermuteAndReorderTest {
|
||||||
|
protected:
|
||||||
|
void CreateGraph() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FusePermuteAndReorderTest2 : public FusePermuteAndReorderTest {
|
||||||
|
protected:
|
||||||
|
void CreateGraph() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace LayerTestsDefinitions
|
} // namespace LayerTestsDefinitions
|
||||||
|
@ -21,40 +21,11 @@ std::string FusePermuteAndReorderTest::getTestCaseName(testing::TestParamInfo<Fu
|
|||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void FusePermuteAndReorderTest::SetUp() {
|
void FusePermuteAndReorderTest::CheckPermuteCount(size_t expectedPermuteCount) {
|
||||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
|
||||||
SizeVector inputShape;
|
|
||||||
Precision inPrec;
|
|
||||||
|
|
||||||
std::tie(inputShape, inPrec) = this->GetParam();
|
|
||||||
|
|
||||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
|
||||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
|
||||||
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
|
||||||
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
|
||||||
|
|
||||||
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
|
|
||||||
auto memFmt = inputShape.size() == 5 ? ndhwc : nhwc;
|
|
||||||
|
|
||||||
auto constOrder = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
|
||||||
|
|
||||||
auto permute = std::make_shared<ngraph::opset5::Transpose>(paramOuts[0], constOrder);
|
|
||||||
|
|
||||||
permute->get_rt_info() = setCPUInfo({memFmt}, {memFmt}, {});
|
|
||||||
|
|
||||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(permute)};
|
|
||||||
function = std::make_shared<ngraph::Function>(results, params, "PermuteReorder");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
|
|
||||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
|
||||||
|
|
||||||
Run();
|
|
||||||
|
|
||||||
InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
|
InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
|
||||||
auto function = execGraphInfo.getFunction();
|
auto function = execGraphInfo.getFunction();
|
||||||
ASSERT_NE(nullptr, function);
|
ASSERT_NE(nullptr, function);
|
||||||
bool permuteFound = false;
|
size_t actualPermuteCount = 0;
|
||||||
for (const auto &node : function->get_ops()) {
|
for (const auto &node : function->get_ops()) {
|
||||||
const auto & rtInfo = node->get_rt_info();
|
const auto & rtInfo = node->get_rt_info();
|
||||||
auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
|
auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
|
||||||
@ -65,18 +36,204 @@ TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
|
|||||||
return value->get();
|
return value->get();
|
||||||
};
|
};
|
||||||
if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == "Permute") {
|
if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == "Permute") {
|
||||||
permuteFound = true;
|
actualPermuteCount++;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ASSERT_TRUE(!permuteFound);
|
|
||||||
|
ASSERT_EQ(expectedPermuteCount, actualPermuteCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto fusePermuteAndReorderParams = ::testing::Combine(
|
void FusePermuteAndReorderTest::SetUp() {
|
||||||
|
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||||
|
|
||||||
|
std::tie(inputShape, inPrec) = this->GetParam();
|
||||||
|
CreateGraph();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto fusePermuteAndReorderCommonParams = ::testing::Combine(
|
||||||
::testing::Values(SizeVector{1, 2, 3, 4}, SizeVector{1, 2, 3, 4, 5}),
|
::testing::Values(SizeVector{1, 2, 3, 4}, SizeVector{1, 2, 3, 4, 5}),
|
||||||
::testing::Values(Precision::I8, Precision::U8)
|
::testing::Values(Precision::I8, Precision::U8)
|
||||||
);
|
);
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest, fusePermuteAndReorderParams, FusePermuteAndReorderTest::getTestCaseName);
|
/* FusePermuteAndReorderTest graph
|
||||||
|
---------
|
||||||
|
|Input |
|
||||||
|
---------
|
||||||
|
|
|
||||||
|
-------------
|
||||||
|
| --------- |
|
||||||
|
| |Permute| |
|
||||||
|
| --------- |
|
||||||
|
| | |
|
||||||
|
| --------- |
|
||||||
|
| |Reorder| |
|
||||||
|
| --------- |
|
||||||
|
|-----------|
|
||||||
|
|
|
||||||
|
---------
|
||||||
|
|Output |
|
||||||
|
---------
|
||||||
|
*/
|
||||||
|
|
||||||
|
void FusePermuteAndReorderTest::CreateGraph() {
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||||
|
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||||
|
|
||||||
|
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
|
||||||
|
auto memFmt = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||||
|
|
||||||
|
auto constOrder = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||||
|
auto permute = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder);
|
||||||
|
permute->get_rt_info() = setCPUInfo({memFmt}, {memFmt}, {});
|
||||||
|
|
||||||
|
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(permute)};
|
||||||
|
function = std::make_shared<ngraph::Function>(results, params, "PermuteReorder");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
|
||||||
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
|
Run();
|
||||||
|
CheckPermuteCount(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||||
|
|
||||||
|
|
||||||
|
/* FusePermuteAndReorderTest1 graph
|
||||||
|
---------
|
||||||
|
|Input |
|
||||||
|
---------
|
||||||
|
|
|
||||||
|
---------
|
||||||
|
|Permute|
|
||||||
|
---------
|
||||||
|
|
|
||||||
|
-------------------
|
||||||
|
| |
|
||||||
|
| -------------
|
||||||
|
| | --------- |
|
||||||
|
| | |Permute| |
|
||||||
|
--------- | --------- |
|
||||||
|
|Reshape| | | |
|
||||||
|
--------- | --------- |
|
||||||
|
| | |Reorder| |
|
||||||
|
| | --------- |
|
||||||
|
| |-----------|
|
||||||
|
| |
|
||||||
|
| ---------
|
||||||
|
| |Permute|
|
||||||
|
| ---------
|
||||||
|
| |
|
||||||
|
-------- --------
|
||||||
|
| |
|
||||||
|
---------
|
||||||
|
|Concat |
|
||||||
|
---------
|
||||||
|
|
|
||||||
|
---------
|
||||||
|
|Output |
|
||||||
|
---------
|
||||||
|
*/
|
||||||
|
|
||||||
|
void FusePermuteAndReorderTest1::CreateGraph() {
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||||
|
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||||
|
|
||||||
|
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
|
||||||
|
|
||||||
|
auto constOrder1 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||||
|
auto permute1 = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder1);
|
||||||
|
auto memFmt1 = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||||
|
permute1->get_rt_info() = setCPUInfo({memFmt1}, {memFmt1}, {});
|
||||||
|
|
||||||
|
auto constOrder2 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||||
|
auto permute2 = std::make_shared<ngraph::opset5::Transpose>(permute1, constOrder2);
|
||||||
|
auto memFmt2 = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||||
|
permute2->get_rt_info() = setCPUInfo({memFmt2}, {memFmt2}, {});
|
||||||
|
|
||||||
|
auto constOrder3 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||||
|
auto permute3 = std::make_shared<ngraph::opset5::Transpose>(permute2, constOrder3);
|
||||||
|
auto memFmt3 = inputShape.size() == 5 ? ncdhw : nchw;
|
||||||
|
permute3->get_rt_info() = setCPUInfo({memFmt3}, {memFmt3}, {});
|
||||||
|
|
||||||
|
auto shape = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, permute3->get_output_shape(0));
|
||||||
|
auto reshape = std::make_shared<ngraph::opset5::Reshape>(permute1, shape, false);
|
||||||
|
|
||||||
|
auto concat = ngraph::builder::makeConcat({permute3, reshape}, 1);
|
||||||
|
|
||||||
|
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(concat)};
|
||||||
|
function = std::make_shared<ngraph::Function>(results, params, "Permute_PermuteReorderPermute_Reshape_Concat");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(FusePermuteAndReorderTest1, CompareWithRefs) {
|
||||||
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
|
Run();
|
||||||
|
CheckPermuteCount(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest1, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||||
|
|
||||||
|
|
||||||
|
/* FusePermuteAndReorderTest2 graph
|
||||||
|
--------- ---------
|
||||||
|
|Input | |Input |
|
||||||
|
--------- ---------
|
||||||
|
| |
|
||||||
|
| -------------
|
||||||
|
--------- | --------- |
|
||||||
|
|Reorder| | |Permute| |
|
||||||
|
--------- | --------- |
|
||||||
|
| | | |
|
||||||
|
--------- | --------- |
|
||||||
|
|Permute| | |Reorder| |
|
||||||
|
--------- | --------- |
|
||||||
|
| |-----------|
|
||||||
|
| |
|
||||||
|
-------- --------
|
||||||
|
| |
|
||||||
|
---------
|
||||||
|
|Concat |
|
||||||
|
---------
|
||||||
|
|
|
||||||
|
---------
|
||||||
|
|Output |
|
||||||
|
---------
|
||||||
|
*/
|
||||||
|
|
||||||
|
void FusePermuteAndReorderTest2::CreateGraph() {
|
||||||
|
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||||
|
|
||||||
|
auto inputShape2(inputShape);
|
||||||
|
inputShape2[inputShape2.size() - 1] *= 2;
|
||||||
|
auto params = ngraph::builder::makeParams(ngPrc, {inputShape, inputShape2});
|
||||||
|
|
||||||
|
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 4, 1, 2, 3} : std::vector<int64_t>{0, 3, 1, 2};
|
||||||
|
|
||||||
|
auto constOrder1 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||||
|
auto permute1 = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder1);
|
||||||
|
auto memFmt1 = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||||
|
permute1->get_rt_info() = setCPUInfo({memFmt1}, {memFmt1}, {});
|
||||||
|
|
||||||
|
auto constOrder2 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||||
|
auto permute2 = std::make_shared<ngraph::opset5::Transpose>(params[1], constOrder2);
|
||||||
|
auto memFmt2 = inputShape.size() == 5 ? ncdhw : nchw;
|
||||||
|
permute2->get_rt_info() = setCPUInfo({memFmt2}, {memFmt2}, {});
|
||||||
|
|
||||||
|
auto concat = ngraph::builder::makeConcat({permute1, permute2}, 1);
|
||||||
|
|
||||||
|
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(concat)};
|
||||||
|
function = std::make_shared<ngraph::Function>(results, params, "Permute_Permute_Concat");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(FusePermuteAndReorderTest2, CompareWithRefs) {
|
||||||
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
|
Run();
|
||||||
|
CheckPermuteCount(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest2, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||||
|
|
||||||
} // namespace LayerTestsDefinitions
|
} // namespace LayerTestsDefinitions
|
||||||
|
Loading…
Reference in New Issue
Block a user