[CPU] FullyConnectedNode: fix dimension normalization in case with 2D input and 3D output (#8595)
This commit is contained in:
parent
d2c2b5e45c
commit
5d86cce4eb
@ -296,13 +296,16 @@ void MKLDNNFullyConnectedNode::createDescriptorInternal(const mkldnn::memory::de
|
|||||||
|
|
||||||
if (in_candidate.dims().size() == 3) {
|
if (in_candidate.dims().size() == 3) {
|
||||||
auto inDims = in_candidate.dims();
|
auto inDims = in_candidate.dims();
|
||||||
auto outDims = out_candidate.dims();
|
|
||||||
auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
|
auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
|
||||||
auto normalizedOutDims = {outDims[0] * outDims[1], outDims[2]};
|
|
||||||
in_candidate = mkldnn::memory::desc(normalizedInDims, in_candidate.data_type(),
|
in_candidate = mkldnn::memory::desc(normalizedInDims, in_candidate.data_type(),
|
||||||
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedInDims.size()));
|
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedInDims.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (out_candidate.dims().size() == 3) {
|
||||||
|
auto outDims = out_candidate.dims();
|
||||||
|
auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
|
||||||
out_candidate = mkldnn::memory::desc(normalizedOutDims, out_candidate.data_type(),
|
out_candidate = mkldnn::memory::desc(normalizedOutDims, out_candidate.data_type(),
|
||||||
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
|
MKLDNNExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
mkldnn::memory::desc wgh_candidate(MKLDNNExtensionUtils::convertToDnnlDims(weightsDims), wdt, mkldnn::memory::format_tag::any);
|
mkldnn::memory::desc wgh_candidate(MKLDNNExtensionUtils::convertToDnnlDims(weightsDims), wdt, mkldnn::memory::format_tag::any);
|
||||||
|
@ -188,6 +188,8 @@ const std::vector<ShapeRelatedParams> IS3D = {
|
|||||||
{{{7, 32, 120}, true}, {{120, 50}, false}},
|
{{{7, 32, 120}, true}, {{120, 50}, false}},
|
||||||
{{{7, 32, 120}, false}, {{120, 50}, true}},
|
{{{7, 32, 120}, false}, {{120, 50}, true}},
|
||||||
{{{7, 32, 120}, true}, {{120, 50}, true}},
|
{{{7, 32, 120}, true}, {{120, 50}, true}},
|
||||||
|
|
||||||
|
{{{1, 429}, false}, {{1, 429, 1}, true}},
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<fusingSpecificParams> fusingParamsSet3D {
|
std::vector<fusingSpecificParams> fusingParamsSet3D {
|
||||||
|
Loading…
Reference in New Issue
Block a user