[CPU] FullyConnectedNode: fix dimension normalization in case with 2D input and 3D output (#8595)

This commit is contained in:
Vladislav Golubev 2021-11-16 17:37:59 +03:00 committed by GitHub
parent d2c2b5e45c
commit 5d86cce4eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@ -296,11 +296,14 @@ void MKLDNNFullyConnectedNode::createDescriptorInternal(const mkldnn::memory::de
if (in_candidate.dims().size() == 3) {
auto inDims = in_candidate.dims();
auto outDims = out_candidate.dims();
auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
auto normalizedOutDims = {outDims[0] * outDims[1], outDims[2]};
in_candidate = mkldnn::memory::desc(normalizedInDims, in_candidate.data_type(),
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedInDims.size()));
}
if (out_candidate.dims().size() == 3) {
auto outDims = out_candidate.dims();
auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
out_candidate = mkldnn::memory::desc(normalizedOutDims, out_candidate.data_type(),
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
}

View File

@ -188,6 +188,8 @@ const std::vector<ShapeRelatedParams> IS3D = {
{{{7, 32, 120}, true}, {{120, 50}, false}},
{{{7, 32, 120}, false}, {{120, 50}, true}},
{{{7, 32, 120}, true}, {{120, 50}, true}},
{{{1, 429}, false}, {{1, 429, 1}, true}},
};
std::vector<fusingSpecificParams> fusingParamsSet3D {