[CPU] Fix FuseConvolutionSumAndConvolutionSumActivation (#9595)

This commit is contained in:
Egor Shulman 2022-01-21 17:24:46 +03:00 committed by GitHub
parent d6bcfb7b8f
commit cf84f43b78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 0 deletions

View File

@ -1119,10 +1119,21 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
isSuitableParent2 = isSuitableParent2 && canFuseSum(binConvNode2, graphNode);
}
auto checkFusedWithSum = [](MKLDNNConvolutionNode* conv) -> bool {
for (const auto& node : conv->getFusedWith()) {
const auto eltwise = std::dynamic_pointer_cast<MKLDNNEltwiseNode>(node);
if (eltwise && eltwise->isSpecialConvolutionAddFusing())
return true;
}
return false;
};
auto* convNode1 = dynamic_cast<MKLDNNConvolutionNode *>(parent1.get());
if (convNode1) {
if (!convNode1->canBeExecutedInInt8()) {
isSuitableParent1 = isSuitableParent1 && convNode1->getFusedWith().empty();
} else {
isSuitableParent1 = isSuitableParent1 && !checkFusedWithSum(convNode1);
}
}
@ -1130,6 +1141,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
if (convNode2) {
if (!convNode2->canBeExecutedInInt8()) {
isSuitableParent2 = isSuitableParent2 && convNode2->getFusedWith().empty();
} else {
isSuitableParent2 = isSuitableParent2 && !checkFusedWithSum(convNode2);
}
}

View File

@ -0,0 +1,82 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace ngraph;
using ngraph::helpers::EltwiseTypes;
namespace SubgraphTestsDefinitions {
/* We can't fuse EltwiseAdd several times into one convolution
FQ1 FQ2
\ /
ADD1 CONV1 [canBeExecutedInInt8]
\ /
\ /
ADD2 CONV2 [canBeExecutedInInt8]
\ /
\ /
ADD3
|
RELU
|
RESULT
*/
class ConvsAndSums : virtual public LayerTestsUtils::LayerTestsCommon {
protected:
void SetUp() override {
InferenceEngine::Precision netPrecision = InferenceEngine::Precision::FP32;
targetDevice = CommonTestUtils::DEVICE_CPU;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {{1, 512, 32}, {1, 128, 32}});
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
auto FQ = ngraph::builder::makeFakeQuantize(paramOuts[1], ngPrc, 256, {}, {-2.8215785026550293}, {2.799535036087036},
{-2.8215785026550293}, {2.799535036087036});
auto FQ_0 = ngraph::builder::makeFakeQuantize(paramOuts[1], ngPrc, 256, {}, {-5.031249523162842}, {4.991942882537842},
{-5.031249523162842}, {4.991942882537842});
auto Add_0 = ngraph::builder::makeEltwise(FQ_0, FQ, EltwiseTypes::ADD);
auto FQ_1 = ngraph::builder::makeFakeQuantize(paramOuts[0], ngPrc, 256, {}, {-2.122633457183838}, {2.106050491333008},
{-2.122633457183838}, {2.106050491333008});
auto Const = ngraph::builder::makeConstant(ngPrc, {128, 512, 1}, std::vector<float>{-0.0512377955019474}, false);
auto FQ_2 = ngraph::builder::makeFakeQuantize(Const, ngPrc, 255, {128, 1, 1}, {-0.56387859582901}, {0.56387859582901},
{-0.56387859582901}, {0.56387859582901});
auto Conv = std::make_shared<ngraph::opset1::Convolution>(FQ_1, FQ_2, Strides{1}, CoordinateDiff{0}, CoordinateDiff{0}, Strides{1});
auto Add = ngraph::builder::makeEltwise(Add_0, Conv, EltwiseTypes::ADD);
auto FQ_11 = ngraph::builder::makeFakeQuantize(paramOuts[0], ngPrc, 256, {}, {-3.2050728797912598}, {3.1800332069396973},
{-3.2050728797912598}, {3.1800332069396973});
auto Const_ = ngraph::builder::makeConstant(ngPrc, {128, 512, 1}, std::vector<float>{-0.001183388871140778}, false);
auto FQ_22 = ngraph::builder::makeFakeQuantize(Const_, ngPrc, 255, {128, 1, 1}, {-0.325547456741333}, {0.325547456741333},
{-0.325547456741333}, {0.325547456741333});
auto Conv2 = std::make_shared<ngraph::opset1::Convolution>(FQ_11, FQ_22, Strides{1}, CoordinateDiff{0}, CoordinateDiff{0}, Strides{1});
auto Add2 = ngraph::builder::makeEltwise(Add, Conv2, EltwiseTypes::ADD);
auto relu3 = ngraph::builder::makeActivation(Add2, ngPrc, ngraph::helpers::ActivationTypes::Relu);
auto result = std::make_shared<ngraph::opset1::Result>(relu3);
function = std::make_shared<ngraph::Function>(result, params, "SimpleNet");
}
};
TEST_F(ConvsAndSums, smoke_CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run();
}
} // namespace SubgraphTestsDefinitions