[CPU] Fix convolution plus sum layout alignment (#19279)
This commit is contained in:
parent
982d0f43c4
commit
c6a02b76be
@ -61,7 +61,7 @@ class DnnlExecutor {
|
||||
}
|
||||
|
||||
protected:
|
||||
void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
|
||||
virtual void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
|
||||
|
||||
protected:
|
||||
dnnl::primitive execPrim;
|
||||
|
@ -1362,6 +1362,16 @@ void Convolution::prepareParams() {
|
||||
if (!reorderConvDesc)
|
||||
return nullptr;
|
||||
|
||||
if (key.attr.get()->post_ops_.count(dnnl::impl::primitive_kind::sum)) {
|
||||
return std::make_shared<ConvolutionSumExecutor>(
|
||||
reorderConvDesc,
|
||||
key.inp0->getDnnlDesc(),
|
||||
key.inp1->getDnnlDesc(),
|
||||
key.out->getDnnlDesc(),
|
||||
engine,
|
||||
key.constWeight);
|
||||
}
|
||||
|
||||
return std::make_shared<ConvolutionExecutor>(
|
||||
reorderConvDesc,
|
||||
key.inp0->getDnnlDesc(),
|
||||
@ -1440,6 +1450,46 @@ Convolution::ConvolutionExecutor::ConvolutionExecutor(const dnnl::primitive_desc
|
||||
}
|
||||
}
|
||||
|
||||
Convolution::ConvolutionSumExecutor::ConvolutionSumExecutor(const dnnl::primitive_desc& pd,
|
||||
const dnnl::memory::desc& inMemDesc,
|
||||
const dnnl::memory::desc& weightMemDesc,
|
||||
const dnnl::memory::desc& outMemDesc,
|
||||
const dnnl::engine& engine,
|
||||
bool constWeight) : DnnlExecutor(pd) {
|
||||
if (inMemDesc != getDnnlSrcDesc()) {
|
||||
inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)});
|
||||
}
|
||||
|
||||
if (!constWeight && weightMemDesc != getDnnlWeightDesc()) {
|
||||
// const weight will be reordered at first execution
|
||||
inputReorders.insert({DNNL_ARG_WEIGHTS, IntermReorder(weightMemDesc, getDnnlWeightDesc(), engine)});
|
||||
}
|
||||
|
||||
if (outMemDesc != getDnnlDstDesc()) {
|
||||
// In the case of fusing sum, we have to reorder the output data before executing the primitive,
|
||||
// since the output data are used as an accumulator for the covolution computations.
|
||||
inputReorders.insert({DNNL_ARG_DST, IntermReorder(outMemDesc, getDnnlDstDesc(), engine)});
|
||||
outputReorders.insert({DNNL_ARG_DST, IntermReorder(getDnnlDstDesc(), outMemDesc, engine)});
|
||||
}
|
||||
}
|
||||
|
||||
void Convolution::ConvolutionSumExecutor::reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) {
|
||||
auto outputMem = primArgs.at(DNNL_ARG_DST);
|
||||
for (auto &inReorder : inputReorders) {
|
||||
if (primArgs.count(inReorder.first)) {
|
||||
dnnl::memory memDst(inReorder.second.getDstDesc(), strm.get_engine());
|
||||
inReorder.second.exec(primArgs[inReorder.first], memDst, strm);
|
||||
primArgs[inReorder.first] = memDst;
|
||||
} else {
|
||||
IE_THROW() << "DnnlExecutor has reorder for input " << inReorder.first << ", but doesn't have source memory";
|
||||
}
|
||||
}
|
||||
execPrim.execute(strm, primArgs);
|
||||
if (!outputReorders.empty()) {
|
||||
outputReorders.at(DNNL_ARG_DST).exec(primArgs.at(DNNL_ARG_DST), outputMem, strm);
|
||||
}
|
||||
}
|
||||
|
||||
void Convolution::execute(dnnl::stream strm) {
|
||||
if (!execPtr) {
|
||||
IE_THROW() << "Can't execute Convolution node with name: " << getName() << ", because executor is not compiled";
|
||||
|
@ -94,6 +94,19 @@ private:
|
||||
bool constWeight);
|
||||
};
|
||||
|
||||
class ConvolutionSumExecutor : public DnnlExecutor {
|
||||
public:
|
||||
ConvolutionSumExecutor(const dnnl::primitive_desc& pd,
|
||||
const dnnl::memory::desc& inMemDesc,
|
||||
const dnnl::memory::desc& weightMemDesc,
|
||||
const dnnl::memory::desc& outMemDesc,
|
||||
const dnnl::engine& engine,
|
||||
bool constWeight);
|
||||
|
||||
private:
|
||||
void reorder_exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm) override;
|
||||
};
|
||||
|
||||
void prepareParams() override;
|
||||
void execute(dnnl::stream strm) override;
|
||||
void executeDynamicImpl(dnnl::stream strm) override;
|
||||
|
@ -23,8 +23,8 @@ typedef std::tuple<
|
||||
> convSumBroadcastParamSet;
|
||||
|
||||
|
||||
class ConcatConvSumInPlaceTest : public testing::WithParamInterface<convSumBroadcastParamSet>,
|
||||
virtual public SubgraphBaseTest, public CpuTestWithFusing {
|
||||
class ConvSumInPlaceTest : public testing::WithParamInterface<convSumBroadcastParamSet>,
|
||||
virtual public SubgraphBaseTest, public CpuTestWithFusing {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<convSumBroadcastParamSet>& obj) {
|
||||
InputShape convShape;
|
||||
@ -137,21 +137,47 @@ protected:
|
||||
|
||||
protected:
|
||||
ov::element::Type runtimeType;
|
||||
const InferenceEngine::SizeVector _kernel = {3, 3};
|
||||
const InferenceEngine::SizeVector _stride = {1, 1};
|
||||
const InferenceEngine::SizeVector _dilation = {1, 1};
|
||||
const std::vector<ptrdiff_t> _padBegin = {0, 0};
|
||||
const std::vector<ptrdiff_t> _padEnd = {0, 0};
|
||||
const size_t _convOutChannels = 64;
|
||||
InferenceEngine::SizeVector _kernel = {3, 3};
|
||||
InferenceEngine::SizeVector _stride = {1, 1};
|
||||
InferenceEngine::SizeVector _dilation = {1, 1};
|
||||
std::vector<ptrdiff_t> _padBegin = {0, 0};
|
||||
std::vector<ptrdiff_t> _padEnd = {0, 0};
|
||||
size_t _convOutChannels = 64;
|
||||
};
|
||||
|
||||
TEST_P(ConcatConvSumInPlaceTest, CompareWithRefs) {
|
||||
TEST_P(ConvSumInPlaceTest, CompareWithRefs) {
|
||||
run();
|
||||
|
||||
CheckPluginRelatedResults(compiledModel, "Convolution");
|
||||
}
|
||||
|
||||
class ConcatConvSumInPlaceTestInt8 : public ConcatConvSumInPlaceTest {
|
||||
class ConvSumInPlaceStrided : public ConvSumInPlaceTest {
|
||||
public:
|
||||
ConvSumInPlaceStrided() {
|
||||
_kernel = {1, 1};
|
||||
_stride = {2, 2};
|
||||
_convOutChannels = 128;
|
||||
rel_threshold = 1e-4;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool primTypeCheck(std::string primType) const override {
|
||||
auto isaType = getISA(runtimeType == ov::element::Type_t::f32);
|
||||
if (isaType == "")
|
||||
return primType == "ref";
|
||||
else
|
||||
return primType == makeSelectedTypeStr(std::string("jit_") + isaType + std::string("_1x1"), runtimeType)
|
||||
|| primType == makeSelectedTypeStr(std::string("brgconv_") + isaType+ std::string("_1x1"), runtimeType);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConvSumInPlaceStrided, CompareWithRefs) {
|
||||
run();
|
||||
|
||||
CheckPluginRelatedResults(compiledModel, "Convolution");
|
||||
}
|
||||
|
||||
class ConvSumInPlaceTestInt8 : public ConvSumInPlaceTest {
|
||||
public:
|
||||
ngraph::ParameterVector makeParams() override {
|
||||
ngraph::ParameterVector outs(2);
|
||||
@ -201,7 +227,7 @@ public:
|
||||
void SetUp() override {
|
||||
abs_threshold = 1.001f;
|
||||
using ngraph::pass::ConvertPrecision;
|
||||
ConcatConvSumInPlaceTest::SetUp();
|
||||
ConvSumInPlaceTest::SetUp();
|
||||
functionRefs = function->clone();
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::u8, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
|
||||
@ -209,13 +235,13 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConcatConvSumInPlaceTestInt8, CompareWithRefs) {
|
||||
TEST_P(ConvSumInPlaceTestInt8, CompareWithRefs) {
|
||||
run();
|
||||
|
||||
CheckPluginRelatedResults(compiledModel, "Convolution");
|
||||
}
|
||||
|
||||
class ConcatConvSumInPlaceTestSeveralConsumers : public ConcatConvSumInPlaceTest {
|
||||
class ConvSumInPlaceTestSeveralConsumers : public ConvSumInPlaceTest {
|
||||
public:
|
||||
std::shared_ptr<ngraph::Node> addSum(std::shared_ptr<ngraph::Node> lastNode, const ngraph::ParameterVector& inputParams) override {
|
||||
auto sum = std::make_shared<ngraph::opset3::Add>(lastNode, inputParams[1]);
|
||||
@ -226,7 +252,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConcatConvSumInPlaceTestSeveralConsumers, CompareWithRefs) {
|
||||
TEST_P(ConvSumInPlaceTestSeveralConsumers, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
run();
|
||||
@ -356,41 +382,68 @@ const std::vector<InputShape> secondInp = {
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConvSumInPlaceTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(convInpShape),
|
||||
::testing::ValuesIn(secondInp),
|
||||
::testing::Values(true, false),
|
||||
::testing::ValuesIn(fusingParamsSet),
|
||||
::testing::Values(cpuEmptyPluginConfig)),
|
||||
ConcatConvSumInPlaceTest::getTestCaseName);
|
||||
ConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConvSumInPlaceTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(convInpShape),
|
||||
::testing::ValuesIn(secondInp),
|
||||
::testing::Values(true, false),
|
||||
::testing::ValuesIn(fusingParamsSetBF16),
|
||||
::testing::Values(cpuBF16PluginConfig)),
|
||||
ConcatConvSumInPlaceTest::getTestCaseName);
|
||||
ConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTestInt8,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConvSumInPlaceTestInt8,
|
||||
::testing::Combine(
|
||||
::testing::Values(convInpShape),
|
||||
::testing::ValuesIn(secondInp),
|
||||
::testing::Values(true, false),
|
||||
::testing::ValuesIn(fusingParamsSet),
|
||||
::testing::Values(cpuEmptyPluginConfig)),
|
||||
ConcatConvSumInPlaceTest::getTestCaseName);
|
||||
ConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Several_Consumers, ConcatConvSumInPlaceTestSeveralConsumers,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Several_Consumers, ConvSumInPlaceTestSeveralConsumers,
|
||||
::testing::Combine(
|
||||
::testing::Values(convInpShape),
|
||||
::testing::ValuesIn(secondInp),
|
||||
::testing::Values(true),
|
||||
::testing::Values(emptyFusingSpec),
|
||||
::testing::Values(cpuEmptyPluginConfig)),
|
||||
ConcatConvSumInPlaceTest::getTestCaseName);
|
||||
ConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
InputShape convInpShapeStrided = {
|
||||
//dynamic shapes
|
||||
{-1, 64, -1, -1},
|
||||
{ //target static shapes
|
||||
{1, 64, 147, 147},
|
||||
{1, 64, 147, 147},
|
||||
}
|
||||
};
|
||||
|
||||
InputShape secondInpStrided = {
|
||||
//dynamic shapes
|
||||
{-1, 128, -1, -1},
|
||||
{ //target static shapes
|
||||
{1, 128, 74, 74},
|
||||
{1, 128, 74, 1}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Strided, ConvSumInPlaceStrided,
|
||||
::testing::Combine(
|
||||
::testing::Values(convInpShapeStrided),
|
||||
::testing::Values(secondInpStrided),
|
||||
::testing::Values(true),
|
||||
::testing::Values(emptyFusingSpec),
|
||||
::testing::Values(cpuEmptyPluginConfig)),
|
||||
ConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user