[CPU] fixed MergePermuteAndReorder optimization (#3318)
This commit is contained in:
parent
843d5de611
commit
627c5e6d0e
@ -1083,7 +1083,7 @@ void MKLDNNGraph::RemoveDroppedEdges() {
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const TensorDesc& inDesc, const TensorDesc& outDesc,
|
||||
MKLDNNNodePtr MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const TensorDesc& inDesc, const TensorDesc& outDesc,
|
||||
bool isOptimized, InferenceEngine::Blob::Ptr scales) {
|
||||
CNNLayerPtr layer(new CNNLayer({layerName,
|
||||
"Reorder",
|
||||
@ -1133,6 +1133,7 @@ void MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const
|
||||
}
|
||||
|
||||
graphNodes.push_back(newReorder);
|
||||
return newReorder;
|
||||
}
|
||||
|
||||
void MKLDNNGraph::dumpToDotFile(std::string file) const {
|
||||
|
@ -109,10 +109,10 @@ public:
|
||||
* optimization flag; if isOptimized is true then Reorder node does nothing
|
||||
* @param scales
|
||||
* pointer to the blob containing scales
|
||||
* @return none.
|
||||
* @return pointer to the new Reorder node.
|
||||
*/
|
||||
void InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const InferenceEngine::TensorDesc& inDesc, const InferenceEngine::TensorDesc& outDesc,
|
||||
bool isOptimized = false, InferenceEngine::Blob::Ptr scales = nullptr);
|
||||
MKLDNNNodePtr InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const InferenceEngine::TensorDesc& inDesc,
|
||||
const InferenceEngine::TensorDesc& outDesc, bool isOptimized = false, InferenceEngine::Blob::Ptr scales = nullptr);
|
||||
|
||||
InferenceEngine::CNNNetwork dump() const;
|
||||
|
||||
|
@ -2312,8 +2312,8 @@ void MKLDNNGraphOptimizer::MergePermuteAndReorder(MKLDNNGraph &graph) {
|
||||
graph.DropNode(parentNode);
|
||||
graph.DropNode(childNode);
|
||||
|
||||
auto inDesc = parentParentNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
|
||||
auto outDesc = childChildNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
|
||||
auto inDesc = parentNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
|
||||
auto outDesc = childNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
|
||||
|
||||
auto inPrec = inDesc.getPrecision();
|
||||
auto outPrec = outDesc.getPrecision();
|
||||
@ -2333,13 +2333,12 @@ void MKLDNNGraphOptimizer::MergePermuteAndReorder(MKLDNNGraph &graph) {
|
||||
}
|
||||
}
|
||||
|
||||
graph.InsertReorder(edge, reorderlayerName, reorderInDesc, reorderOutDesc, true);
|
||||
auto reorderNode = graph.InsertReorder(edge, reorderlayerName, reorderInDesc, reorderOutDesc, true);
|
||||
|
||||
// case 2
|
||||
if (inPrec != outPrec) {
|
||||
auto reorderNode = parentParentNode->getChildEdgeAt(0)->getChild();
|
||||
auto reorderInDesc2 = TensorDesc(reorderNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc);
|
||||
auto reorderOutDesc2 = TensorDesc(childChildNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc);
|
||||
auto reorderInDesc2 = TensorDesc(reorderOutDesc);
|
||||
auto reorderOutDesc2 = TensorDesc(outDesc);
|
||||
|
||||
std::string reorderLayerName2 = reorderNode->getName() + "_" +
|
||||
MKLDNNExtensionUtils::getReorderArgs(reorderInDesc2, reorderOutDesc2) + "_" + childChildNode->getName();
|
||||
|
@ -18,8 +18,8 @@ using namespace CPUTestUtils;
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
using FusePermuteAndReorderParams = std::tuple<
|
||||
InferenceEngine::SizeVector, // Input shape
|
||||
InferenceEngine::Precision // Input precision
|
||||
InferenceEngine::SizeVector, // Input shape
|
||||
InferenceEngine::Precision // Input precision
|
||||
>;
|
||||
|
||||
class FusePermuteAndReorderTest : public testing::WithParamInterface<FusePermuteAndReorderParams>, public CPUTestsBase,
|
||||
@ -29,7 +29,21 @@ public:
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
std::string pluginTypeNode;
|
||||
virtual void CreateGraph();
|
||||
void CheckPermuteCount(size_t expectedPermuteCount);
|
||||
|
||||
InferenceEngine::SizeVector inputShape;
|
||||
InferenceEngine::Precision inPrec;
|
||||
};
|
||||
|
||||
class FusePermuteAndReorderTest1 : public FusePermuteAndReorderTest {
|
||||
protected:
|
||||
void CreateGraph() override;
|
||||
};
|
||||
|
||||
class FusePermuteAndReorderTest2 : public FusePermuteAndReorderTest {
|
||||
protected:
|
||||
void CreateGraph() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
@ -21,40 +21,11 @@ std::string FusePermuteAndReorderTest::getTestCaseName(testing::TestParamInfo<Fu
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void FusePermuteAndReorderTest::SetUp() {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
SizeVector inputShape;
|
||||
Precision inPrec;
|
||||
|
||||
std::tie(inputShape, inPrec) = this->GetParam();
|
||||
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
||||
|
||||
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
|
||||
auto memFmt = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||
|
||||
auto constOrder = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
|
||||
auto permute = std::make_shared<ngraph::opset5::Transpose>(paramOuts[0], constOrder);
|
||||
|
||||
permute->get_rt_info() = setCPUInfo({memFmt}, {memFmt}, {});
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(permute)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "PermuteReorder");
|
||||
}
|
||||
|
||||
TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
|
||||
void FusePermuteAndReorderTest::CheckPermuteCount(size_t expectedPermuteCount) {
|
||||
InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
|
||||
auto function = execGraphInfo.getFunction();
|
||||
ASSERT_NE(nullptr, function);
|
||||
bool permuteFound = false;
|
||||
size_t actualPermuteCount = 0;
|
||||
for (const auto &node : function->get_ops()) {
|
||||
const auto & rtInfo = node->get_rt_info();
|
||||
auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
|
||||
@ -65,18 +36,204 @@ TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
|
||||
return value->get();
|
||||
};
|
||||
if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == "Permute") {
|
||||
permuteFound = true;
|
||||
break;
|
||||
actualPermuteCount++;
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(!permuteFound);
|
||||
|
||||
ASSERT_EQ(expectedPermuteCount, actualPermuteCount);
|
||||
}
|
||||
|
||||
const auto fusePermuteAndReorderParams = ::testing::Combine(
|
||||
void FusePermuteAndReorderTest::SetUp() {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
|
||||
std::tie(inputShape, inPrec) = this->GetParam();
|
||||
CreateGraph();
|
||||
}
|
||||
|
||||
const auto fusePermuteAndReorderCommonParams = ::testing::Combine(
|
||||
::testing::Values(SizeVector{1, 2, 3, 4}, SizeVector{1, 2, 3, 4, 5}),
|
||||
::testing::Values(Precision::I8, Precision::U8)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest, fusePermuteAndReorderParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||
/* FusePermuteAndReorderTest graph
|
||||
---------
|
||||
|Input |
|
||||
---------
|
||||
|
|
||||
-------------
|
||||
| --------- |
|
||||
| |Permute| |
|
||||
| --------- |
|
||||
| | |
|
||||
| --------- |
|
||||
| |Reorder| |
|
||||
| --------- |
|
||||
|-----------|
|
||||
|
|
||||
---------
|
||||
|Output |
|
||||
---------
|
||||
*/
|
||||
|
||||
void FusePermuteAndReorderTest::CreateGraph() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||
|
||||
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
|
||||
auto memFmt = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||
|
||||
auto constOrder = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
auto permute = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder);
|
||||
permute->get_rt_info() = setCPUInfo({memFmt}, {memFmt}, {});
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(permute)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "PermuteReorder");
|
||||
}
|
||||
|
||||
TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
CheckPermuteCount(0);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||
|
||||
|
||||
/* FusePermuteAndReorderTest1 graph
|
||||
---------
|
||||
|Input |
|
||||
---------
|
||||
|
|
||||
---------
|
||||
|Permute|
|
||||
---------
|
||||
|
|
||||
-------------------
|
||||
| |
|
||||
| -------------
|
||||
| | --------- |
|
||||
| | |Permute| |
|
||||
--------- | --------- |
|
||||
|Reshape| | | |
|
||||
--------- | --------- |
|
||||
| | |Reorder| |
|
||||
| | --------- |
|
||||
| |-----------|
|
||||
| |
|
||||
| ---------
|
||||
| |Permute|
|
||||
| ---------
|
||||
| |
|
||||
-------- --------
|
||||
| |
|
||||
---------
|
||||
|Concat |
|
||||
---------
|
||||
|
|
||||
---------
|
||||
|Output |
|
||||
---------
|
||||
*/
|
||||
|
||||
void FusePermuteAndReorderTest1::CreateGraph() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||
|
||||
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
|
||||
|
||||
auto constOrder1 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
auto permute1 = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder1);
|
||||
auto memFmt1 = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||
permute1->get_rt_info() = setCPUInfo({memFmt1}, {memFmt1}, {});
|
||||
|
||||
auto constOrder2 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
auto permute2 = std::make_shared<ngraph::opset5::Transpose>(permute1, constOrder2);
|
||||
auto memFmt2 = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||
permute2->get_rt_info() = setCPUInfo({memFmt2}, {memFmt2}, {});
|
||||
|
||||
auto constOrder3 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
auto permute3 = std::make_shared<ngraph::opset5::Transpose>(permute2, constOrder3);
|
||||
auto memFmt3 = inputShape.size() == 5 ? ncdhw : nchw;
|
||||
permute3->get_rt_info() = setCPUInfo({memFmt3}, {memFmt3}, {});
|
||||
|
||||
auto shape = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, permute3->get_output_shape(0));
|
||||
auto reshape = std::make_shared<ngraph::opset5::Reshape>(permute1, shape, false);
|
||||
|
||||
auto concat = ngraph::builder::makeConcat({permute3, reshape}, 1);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(concat)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "Permute_PermuteReorderPermute_Reshape_Concat");
|
||||
}
|
||||
|
||||
TEST_P(FusePermuteAndReorderTest1, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
CheckPermuteCount(2);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest1, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||
|
||||
|
||||
/* FusePermuteAndReorderTest2 graph
|
||||
--------- ---------
|
||||
|Input | |Input |
|
||||
--------- ---------
|
||||
| |
|
||||
| -------------
|
||||
--------- | --------- |
|
||||
|Reorder| | |Permute| |
|
||||
--------- | --------- |
|
||||
| | | |
|
||||
--------- | --------- |
|
||||
|Permute| | |Reorder| |
|
||||
--------- | --------- |
|
||||
| |-----------|
|
||||
| |
|
||||
-------- --------
|
||||
| |
|
||||
---------
|
||||
|Concat |
|
||||
---------
|
||||
|
|
||||
---------
|
||||
|Output |
|
||||
---------
|
||||
*/
|
||||
|
||||
void FusePermuteAndReorderTest2::CreateGraph() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
|
||||
auto inputShape2(inputShape);
|
||||
inputShape2[inputShape2.size() - 1] *= 2;
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShape, inputShape2});
|
||||
|
||||
auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 4, 1, 2, 3} : std::vector<int64_t>{0, 3, 1, 2};
|
||||
|
||||
auto constOrder1 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
auto permute1 = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder1);
|
||||
auto memFmt1 = inputShape.size() == 5 ? ndhwc : nhwc;
|
||||
permute1->get_rt_info() = setCPUInfo({memFmt1}, {memFmt1}, {});
|
||||
|
||||
auto constOrder2 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
|
||||
auto permute2 = std::make_shared<ngraph::opset5::Transpose>(params[1], constOrder2);
|
||||
auto memFmt2 = inputShape.size() == 5 ? ncdhw : nchw;
|
||||
permute2->get_rt_info() = setCPUInfo({memFmt2}, {memFmt2}, {});
|
||||
|
||||
auto concat = ngraph::builder::makeConcat({permute1, permute2}, 1);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(concat)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "Permute_Permute_Concat");
|
||||
}
|
||||
|
||||
TEST_P(FusePermuteAndReorderTest2, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
CheckPermuteCount(1);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest2, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user