[CPU] Relax threshold for using brgconv1x1 instead of inner product (#14351)

* relax condition

* modify new shape due to threshold changed

* rename M to an exact name

* apply review comments
This commit is contained in:
Luo Cheng 2022-12-06 16:23:32 +08:00 committed by GitHub
parent f12765ebef
commit bf0d2cea4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 13 deletions

View File

@ -817,12 +817,22 @@ bool FullyConnected::canBeExecutedInConv1x1() const {
const auto& srcDims = srcMemPtr->getStaticDims();
auto weightMemPtr = getParentEdgesAtPort(1)[0]->getMemoryPtr();
const auto& weightDims = weightMemPtr->getStaticDims();
Dim M, N, K;
M = srcDims[inRank - 2];
// for original inner product semantics:
// when input is 2D tensor
// M in oneDNN will map to widthInConv
// when input is 3D tensor
// M in oneDNN will map to widthInConv*minibatch
// currently nwc mapping in brg:
// when input is 2D tensor
// widthInConv will map to 'w', 'n' will be 1
// when input is 3D tensor
// widthInConv will map to 'w', 'n' will be minibatch
Dim widthInConv, N, K;
widthInConv = srcDims[inRank - 2];
K = srcDims[inRank - 1];
N = weightDims[0];
if (!(M >= 49 && M <= 3136 &&
if (!(widthInConv >= 2 && widthInConv <= 3136 &&
K >= 96 && K <= 4096 &&
N >= 96 && N <= K * 4))
retVal = false;

View File

@ -408,8 +408,8 @@ const std::vector<ShapeRelatedParams> IS3D_smoke = {
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, true}},
// needed by 'IS3D_Brgconv1x1_smoke'
{static_shapes_to_test_representation({{1, 32, 120}, {120, 120}}), {false, false}},
{static_shapes_to_test_representation({{3, 29, 120}, {120, 120}}), {false, false}},
{static_shapes_to_test_representation({{1, 1, 120}, {120, 120}}), {false, false}},
{static_shapes_to_test_representation({{3, 1, 120}, {120, 120}}), {false, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {true, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {false, true}},
@ -548,8 +548,8 @@ INSTANTIATE_TEST_SUITE_P(nightly_FC_3D_BF16, MatMulLayerCPUTest, testParams3DBF1
const std::vector<ShapeRelatedParams> IS2D_Brgemm_smoke = {
// needed by 'IS2D_Brgconv1x1_smoke'
{static_shapes_to_test_representation({{29, 120}, {120, 120}}), {true, false}},
{static_shapes_to_test_representation({{39, 120}, {120, 120}}), {true, false}},
{static_shapes_to_test_representation({{1, 120}, {120, 120}}), {true, false}},
{static_shapes_to_test_representation({{1, 128}, {128, 166}}), {true, false}},
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, false}},
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, true}},
@ -624,9 +624,9 @@ const std::vector<ShapeRelatedParams> IS2D_Brgconv1x1_smoke = {
{
{
// ip->brg->ip->brg
// {39, 120}, {29, 120} are covered in 'IS2D_Brgemm_smoke' which is ip
// {1, 120} are covered in 'IS2D_Brgemm_smoke' which is ip
// {49, 120}, {79, 120} are covered above which is brg1x1
{{-1, -1}, {{39, 120}, {49, 120}, {29, 120}, {79, 120}}},
{{-1, -1}, {{1, 120}, {49, 120}, {1, 120}, {79, 120}}},
{{120, 120}, {{120, 120}, {120, 120}, {120, 120}, {120, 120}}}
},
{false, false}
@ -634,7 +634,7 @@ const std::vector<ShapeRelatedParams> IS2D_Brgconv1x1_smoke = {
{
{
// ip->brg->ip(cached)->brg(cached)
{{{0, 200}, {0, 200}}, {{18, 128}, {199, 128}, {18, 128}, {199, 128}}},
{{{0, 200}, {0, 200}}, {{1, 128}, {199, 128}, {1, 128}, {199, 128}}},
{{128, 166}, {{128, 166}, {128, 166}}}
},
{true, true}
@ -670,9 +670,9 @@ const std::vector<ShapeRelatedParams> IS3D_Brgconv1x1_smoke = {
{
{
// ip->brg->ip->brg
// {1, 32, 120}, {3, 29, 120} are covered in 'IS3D_smoke' which is ip
// {1, 1, 120}, {3, 1, 120} are covered in 'IS3D_smoke' which is ip
// {2, 49, 120}, {4, 79, 120} are covered above which is brg1x1
{{-1, -1, -1}, {{1, 32, 120}, {2, 49, 120}, {3, 29, 120}, {4, 79, 120}}},
{{-1, -1, -1}, {{1, 1, 120}, {2, 49, 120}, {3, 1, 120}, {4, 79, 120}}},
{{120, 120}, {{120, 120}, {120, 120}, {120, 120}, {120, 120}}}
},
{false, false}
@ -710,9 +710,30 @@ const auto testParams3D_Brgconv1x1_smoke = ::testing::Combine(fullyConnectedPara
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_Brgconv1x1, MatMulLayerCPUTest, testParams3D_Brgconv1x1_smoke, MatMulLayerCPUTest::getTestCaseName);
const std::vector<ShapeRelatedParams> IS2D_Brgemm_Amx_smoke = {
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, false}},
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, true}},
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, false}},
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, true}},
const auto fullyConnectedParams2D_Brgemm_Amx_smoke = ::testing::Combine(::testing::ValuesIn(IS2D_Brgemm_smoke),
{
{
{{-1, -1}, {{12, 16}, {25, 16}, {12, 16}, {25, 16}}},
{{16, 35}, {{16, 35}, {16, 35}, {16, 35}, {16, 35}}}
},
{false, false}
},
{
{
{{{0, 50}, {0, 50}}, {{17, 48}, {15, 48}}},
{{48, 15}, {{48, 15}, {48, 15}}}
},
{true, true}
},
};
const auto fullyConnectedParams2D_Brgemm_Amx_smoke = ::testing::Combine(::testing::ValuesIn(IS2D_Brgemm_Amx_smoke),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),