[CPU] Fix convolution plus sum layout alignment (#19279)

This commit is contained in:
Maksim Kutakov 2023-08-23 14:29:26 +02:00 committed by GitHub
parent 982d0f43c4
commit c6a02b76be
4 changed files with 139 additions and 23 deletions

View File

@@ -61,7 +61,7 @@ class DnnlExecutor {
}
protected:
void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
virtual void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
protected:
dnnl::primitive execPrim;

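The only change in this first file is the added virtual qualifier on reorder_exec. A rough sketch of why that matters (assumed, simplified call site for illustration only; the real DnnlExecutor::exec body is not part of this diff): the base class funnels execution through reorder_exec whenever reorders are present, so making it virtual lets ConvolutionSumExecutor substitute its own destination handling without touching the generic path.

// Assumed, simplified shape of the base-class call site (not code from this patch):
void DnnlExecutor::exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) {
    if (inputReorders.empty() && outputReorders.empty()) {
        execPrim.execute(strm, primArgs);
    } else {
        reorder_exec(std::move(primArgs), strm);  // virtual dispatch enabled by this change
    }
}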
View File

@@ -1362,6 +1362,16 @@ void Convolution::prepareParams() {
if (!reorderConvDesc)
return nullptr;
if (key.attr.get()->post_ops_.count(dnnl::impl::primitive_kind::sum)) {
return std::make_shared<ConvolutionSumExecutor>(
reorderConvDesc,
key.inp0->getDnnlDesc(),
key.inp1->getDnnlDesc(),
key.out->getDnnlDesc(),
engine,
key.constWeight);
}
return std::make_shared<ConvolutionExecutor>(
reorderConvDesc,
key.inp0->getDnnlDesc(),
@@ -1440,6 +1450,46 @@ Convolution::ConvolutionExecutor::ConvolutionExecutor(const dnnl::primitive_desc
}
}
Convolution::ConvolutionSumExecutor::ConvolutionSumExecutor(const dnnl::primitive_desc& pd,
const dnnl::memory::desc& inMemDesc,
const dnnl::memory::desc& weightMemDesc,
const dnnl::memory::desc& outMemDesc,
const dnnl::engine& engine,
bool constWeight) : DnnlExecutor(pd) {
if (inMemDesc != getDnnlSrcDesc()) {
inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)});
}
if (!constWeight && weightMemDesc != getDnnlWeightDesc()) {
// constant weights will be reordered at the first execution
inputReorders.insert({DNNL_ARG_WEIGHTS, IntermReorder(weightMemDesc, getDnnlWeightDesc(), engine)});
}
if (outMemDesc != getDnnlDstDesc()) {
// In the case of fusing sum, we have to reorder the output data before executing the primitive,
// since the output data are used as an accumulator for the convolution computations.
inputReorders.insert({DNNL_ARG_DST, IntermReorder(outMemDesc, getDnnlDstDesc(), engine)});
outputReorders.insert({DNNL_ARG_DST, IntermReorder(getDnnlDstDesc(), outMemDesc, engine)});
}
}
void Convolution::ConvolutionSumExecutor::reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) {
auto outputMem = primArgs.at(DNNL_ARG_DST);
for (auto &inReorder : inputReorders) {
if (primArgs.count(inReorder.first)) {
dnnl::memory memDst(inReorder.second.getDstDesc(), strm.get_engine());
inReorder.second.exec(primArgs[inReorder.first], memDst, strm);
primArgs[inReorder.first] = memDst;
} else {
IE_THROW() << "DnnlExecutor has reorder for input " << inReorder.first << ", but doesn't have source memory";
}
}
execPrim.execute(strm, primArgs);
if (!outputReorders.empty()) {
outputReorders.at(DNNL_ARG_DST).exec(primArgs.at(DNNL_ARG_DST), outputMem, strm);
}
}
void Convolution::execute(dnnl::stream strm) {
if (!execPtr) {
IE_THROW() << "Can't execute Convolution node with name: " << getName() << ", because executor is not compiled";

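For context, the accumulator behaviour mentioned in the comment above comes from oneDNN's sum post-op: the primitive reads DST as the second addend and accumulates the convolution result into it, which is why ConvolutionSumExecutor must reorder DST into the primitive's expected layout before execution and back to the node's layout afterwards. A minimal, standalone sketch of how such an attribute is typically built (assumed helper name, not code from this patch):

#include <oneapi/dnnl/dnnl.hpp>

dnnl::primitive_attr make_sum_attr() {
    dnnl::post_ops ops;
    ops.append_sum(1.0f);         // dst = conv(src, weights) + 1.0f * dst
    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);
    return attr;                  // passed when creating the convolution primitive_desc
}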
View File

@@ -94,6 +94,19 @@ private:
bool constWeight);
};
class ConvolutionSumExecutor : public DnnlExecutor {
public:
ConvolutionSumExecutor(const dnnl::primitive_desc& pd,
const dnnl::memory::desc& inMemDesc,
const dnnl::memory::desc& weightMemDesc,
const dnnl::memory::desc& outMemDesc,
const dnnl::engine& engine,
bool constWeight);
private:
void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) override;
};
void prepareParams() override;
void execute(dnnl::stream strm) override;
void executeDynamicImpl(dnnl::stream strm) override;

View File

@@ -23,8 +23,8 @@ typedef std::tuple<
> convSumBroadcastParamSet;
class ConcatConvSumInPlaceTest : public testing::WithParamInterface<convSumBroadcastParamSet>,
virtual public SubgraphBaseTest, public CpuTestWithFusing {
class ConvSumInPlaceTest : public testing::WithParamInterface<convSumBroadcastParamSet>,
virtual public SubgraphBaseTest, public CpuTestWithFusing {
public:
static std::string getTestCaseName(const testing::TestParamInfo<convSumBroadcastParamSet>& obj) {
InputShape convShape;
@@ -137,21 +137,47 @@ protected:
protected:
ov::element::Type runtimeType;
const InferenceEngine::SizeVector _kernel = {3, 3};
const InferenceEngine::SizeVector _stride = {1, 1};
const InferenceEngine::SizeVector _dilation = {1, 1};
const std::vector<ptrdiff_t> _padBegin = {0, 0};
const std::vector<ptrdiff_t> _padEnd = {0, 0};
const size_t _convOutChannels = 64;
InferenceEngine::SizeVector _kernel = {3, 3};
InferenceEngine::SizeVector _stride = {1, 1};
InferenceEngine::SizeVector _dilation = {1, 1};
std::vector<ptrdiff_t> _padBegin = {0, 0};
std::vector<ptrdiff_t> _padEnd = {0, 0};
size_t _convOutChannels = 64;
};
TEST_P(ConcatConvSumInPlaceTest, CompareWithRefs) {
TEST_P(ConvSumInPlaceTest, CompareWithRefs) {
run();
CheckPluginRelatedResults(compiledModel, "Convolution");
}
class ConcatConvSumInPlaceTestInt8 : public ConcatConvSumInPlaceTest {
class ConvSumInPlaceStrided : public ConvSumInPlaceTest {
public:
ConvSumInPlaceStrided() {
_kernel = {1, 1};
_stride = {2, 2};
_convOutChannels = 128;
rel_threshold = 1e-4;
}
protected:
bool primTypeCheck(std::string primType) const override {
auto isaType = getISA(runtimeType == ov::element::Type_t::f32);
if (isaType == "")
return primType == "ref";
else
return primType == makeSelectedTypeStr(std::string("jit_") + isaType + std::string("_1x1"), runtimeType)
|| primType == makeSelectedTypeStr(std::string("brgconv_") + isaType + std::string("_1x1"), runtimeType);
}
};
TEST_P(ConvSumInPlaceStrided, CompareWithRefs) {
run();
CheckPluginRelatedResults(compiledModel, "Convolution");
}
class ConvSumInPlaceTestInt8 : public ConvSumInPlaceTest {
public:
ngraph::ParameterVector makeParams() override {
ngraph::ParameterVector outs(2);
@@ -201,7 +227,7 @@ public:
void SetUp() override {
abs_threshold = 1.001f;
using ngraph::pass::ConvertPrecision;
ConcatConvSumInPlaceTest::SetUp();
ConvSumInPlaceTest::SetUp();
functionRefs = function->clone();
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::u8, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
@@ -209,13 +235,13 @@ public:
}
};
TEST_P(ConcatConvSumInPlaceTestInt8, CompareWithRefs) {
TEST_P(ConvSumInPlaceTestInt8, CompareWithRefs) {
run();
CheckPluginRelatedResults(compiledModel, "Convolution");
}
class ConcatConvSumInPlaceTestSeveralConsumers : public ConcatConvSumInPlaceTest {
class ConvSumInPlaceTestSeveralConsumers : public ConvSumInPlaceTest {
public:
std::shared_ptr<ngraph::Node> addSum(std::shared_ptr<ngraph::Node> lastNode, const ngraph::ParameterVector& inputParams) override {
auto sum = std::make_shared<ngraph::opset3::Add>(lastNode, inputParams[1]);
@@ -226,7 +252,7 @@ public:
}
};
TEST_P(ConcatConvSumInPlaceTestSeveralConsumers, CompareWithRefs) {
TEST_P(ConvSumInPlaceTestSeveralConsumers, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
@@ -356,41 +382,68 @@ const std::vector<InputShape> secondInp = {
},
};
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest,
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConvSumInPlaceTest,
::testing::Combine(
::testing::Values(convInpShape),
::testing::ValuesIn(secondInp),
::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(cpuEmptyPluginConfig)),
ConcatConvSumInPlaceTest::getTestCaseName);
ConvSumInPlaceTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest,
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConvSumInPlaceTest,
::testing::Combine(
::testing::Values(convInpShape),
::testing::ValuesIn(secondInp),
::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSetBF16),
::testing::Values(cpuBF16PluginConfig)),
ConcatConvSumInPlaceTest::getTestCaseName);
ConvSumInPlaceTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTestInt8,
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConvSumInPlaceTestInt8,
::testing::Combine(
::testing::Values(convInpShape),
::testing::ValuesIn(secondInp),
::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(cpuEmptyPluginConfig)),
ConcatConvSumInPlaceTest::getTestCaseName);
ConvSumInPlaceTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Several_Consumers, ConcatConvSumInPlaceTestSeveralConsumers,
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Several_Consumers, ConvSumInPlaceTestSeveralConsumers,
::testing::Combine(
::testing::Values(convInpShape),
::testing::ValuesIn(secondInp),
::testing::Values(true),
::testing::Values(emptyFusingSpec),
::testing::Values(cpuEmptyPluginConfig)),
ConcatConvSumInPlaceTest::getTestCaseName);
ConvSumInPlaceTest::getTestCaseName);
InputShape convInpShapeStrided = {
//dynamic shapes
{-1, 64, -1, -1},
{ //target static shapes
{1, 64, 147, 147},
{1, 64, 147, 147},
}
};
InputShape secondInpStrided = {
//dynamic shapes
{-1, 128, -1, -1},
{ //target static shapes
{1, 128, 74, 74},
{1, 128, 74, 1}
}
};
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Strided, ConvSumInPlaceStrided,
::testing::Combine(
::testing::Values(convInpShapeStrided),
::testing::Values(secondInpStrided),
::testing::Values(true),
::testing::Values(emptyFusingSpec),
::testing::Values(cpuEmptyPluginConfig)),
ConvSumInPlaceTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions
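The new ConvSumInPlaceStrided case pins down the regression: a 1x1 convolution with stride 2 whose fused sum input may arrive in a different (broadcast) layout, which forces the DST reorders added above. A hedged sketch of the kind of subgraph this test exercises (assumed helper name, shapes taken from convInpShapeStrided/secondInpStrided; not the test's actual builder code):

#include <ngraph/opsets/opset3.hpp>

std::shared_ptr<ngraph::Function> makeConvSumStridedSketch() {
    using namespace ngraph;
    auto data   = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 64, 147, 147});
    auto addend = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 128, 74, 74});
    // 1x1 kernel, 128 output channels, stride 2: 147x147 -> 74x74
    auto weights = opset3::Constant::create(element::f32, Shape{128, 64, 1, 1},
                                            std::vector<float>(128 * 64, 0.1f));
    auto conv = std::make_shared<opset3::Convolution>(data, weights,
                                                      Strides{2, 2},
                                                      CoordinateDiff{0, 0},
                                                      CoordinateDiff{0, 0},
                                                      Strides{1, 1});
    // The CPU plugin fuses this Add into the convolution as a sum post-op.
    auto sum = std::make_shared<opset3::Add>(conv, addend);
    return std::make_shared<Function>(NodeVector{sum}, ParameterVector{data, addend});
}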