[CPU] FullyConnectedNode: fix dimension normalization in case with 2D input and 3D output (#8595)
This commit is contained in:
parent
d2c2b5e45c
commit
5d86cce4eb
@ -296,11 +296,14 @@ void MKLDNNFullyConnectedNode::createDescriptorInternal(const mkldnn::memory::de
|
||||
|
||||
if (in_candidate.dims().size() == 3) {
|
||||
auto inDims = in_candidate.dims();
|
||||
auto outDims = out_candidate.dims();
|
||||
auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
|
||||
auto normalizedOutDims = {outDims[0] * outDims[1], outDims[2]};
|
||||
in_candidate = mkldnn::memory::desc(normalizedInDims, in_candidate.data_type(),
|
||||
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedInDims.size()));
|
||||
}
|
||||
|
||||
if (out_candidate.dims().size() == 3) {
|
||||
auto outDims = out_candidate.dims();
|
||||
auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
|
||||
out_candidate = mkldnn::memory::desc(normalizedOutDims, out_candidate.data_type(),
|
||||
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
|
||||
}
|
||||
|
@ -188,6 +188,8 @@ const std::vector<ShapeRelatedParams> IS3D = {
|
||||
{{{7, 32, 120}, true}, {{120, 50}, false}},
|
||||
{{{7, 32, 120}, false}, {{120, 50}, true}},
|
||||
{{{7, 32, 120}, true}, {{120, 50}, true}},
|
||||
|
||||
{{{1, 429}, false}, {{1, 429, 1}, true}},
|
||||
};
|
||||
|
||||
std::vector<fusingSpecificParams> fusingParamsSet3D {
|
||||
|
Loading…
Reference in New Issue
Block a user