[CPU] Relax threshold for using brgconv1x1 instead of inner product (#14351)
* relax condition * modify new shape due to threshold changed * rename M to an exact name * apply review comments
This commit is contained in:
parent
f12765ebef
commit
bf0d2cea4b
@ -817,12 +817,22 @@ bool FullyConnected::canBeExecutedInConv1x1() const {
|
||||
const auto& srcDims = srcMemPtr->getStaticDims();
|
||||
auto weightMemPtr = getParentEdgesAtPort(1)[0]->getMemoryPtr();
|
||||
const auto& weightDims = weightMemPtr->getStaticDims();
|
||||
Dim M, N, K;
|
||||
M = srcDims[inRank - 2];
|
||||
// for original inner product semantics:
|
||||
// when input is 2D tensor
|
||||
// M in oneDNN will map to widthInConv
|
||||
// when input is 3D tensor
|
||||
// M in oneDNN will map to widthInConv*minibatch
|
||||
// currently nwc mapping in brg:
|
||||
// when input is 2D tensor
|
||||
// widthInConv will map to 'w', 'n' will be 1
|
||||
// when input is 3D tensor
|
||||
// widthInConv will map to 'w', 'n' will be minibatch
|
||||
Dim widthInConv, N, K;
|
||||
widthInConv = srcDims[inRank - 2];
|
||||
K = srcDims[inRank - 1];
|
||||
N = weightDims[0];
|
||||
|
||||
if (!(M >= 49 && M <= 3136 &&
|
||||
if (!(widthInConv >= 2 && widthInConv <= 3136 &&
|
||||
K >= 96 && K <= 4096 &&
|
||||
N >= 96 && N <= K * 4))
|
||||
retVal = false;
|
||||
|
@ -408,8 +408,8 @@ const std::vector<ShapeRelatedParams> IS3D_smoke = {
|
||||
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, false}},
|
||||
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, true}},
|
||||
// needed by 'IS3D_Brgconv1x1_smoke'
|
||||
{static_shapes_to_test_representation({{1, 32, 120}, {120, 120}}), {false, false}},
|
||||
{static_shapes_to_test_representation({{3, 29, 120}, {120, 120}}), {false, false}},
|
||||
{static_shapes_to_test_representation({{1, 1, 120}, {120, 120}}), {false, false}},
|
||||
{static_shapes_to_test_representation({{3, 1, 120}, {120, 120}}), {false, false}},
|
||||
|
||||
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {true, false}},
|
||||
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {false, true}},
|
||||
@ -548,8 +548,8 @@ INSTANTIATE_TEST_SUITE_P(nightly_FC_3D_BF16, MatMulLayerCPUTest, testParams3DBF1
|
||||
|
||||
const std::vector<ShapeRelatedParams> IS2D_Brgemm_smoke = {
|
||||
// needed by 'IS2D_Brgconv1x1_smoke'
|
||||
{static_shapes_to_test_representation({{29, 120}, {120, 120}}), {true, false}},
|
||||
{static_shapes_to_test_representation({{39, 120}, {120, 120}}), {true, false}},
|
||||
{static_shapes_to_test_representation({{1, 120}, {120, 120}}), {true, false}},
|
||||
{static_shapes_to_test_representation({{1, 128}, {128, 166}}), {true, false}},
|
||||
|
||||
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, false}},
|
||||
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, true}},
|
||||
@ -624,9 +624,9 @@ const std::vector<ShapeRelatedParams> IS2D_Brgconv1x1_smoke = {
|
||||
{
|
||||
{
|
||||
// ip->brg->ip->brg
|
||||
// {39, 120}, {29, 120} are covered in 'IS2D_Brgemm_smoke' which is ip
|
||||
// {1, 120} are covered in 'IS2D_Brgemm_smoke' which is ip
|
||||
// {49, 120}, {79, 120} are covered above which is brg1x1
|
||||
{{-1, -1}, {{39, 120}, {49, 120}, {29, 120}, {79, 120}}},
|
||||
{{-1, -1}, {{1, 120}, {49, 120}, {1, 120}, {79, 120}}},
|
||||
{{120, 120}, {{120, 120}, {120, 120}, {120, 120}, {120, 120}}}
|
||||
},
|
||||
{false, false}
|
||||
@ -634,7 +634,7 @@ const std::vector<ShapeRelatedParams> IS2D_Brgconv1x1_smoke = {
|
||||
{
|
||||
{
|
||||
// ip->brg->ip(cached)->brg(cached)
|
||||
{{{0, 200}, {0, 200}}, {{18, 128}, {199, 128}, {18, 128}, {199, 128}}},
|
||||
{{{0, 200}, {0, 200}}, {{1, 128}, {199, 128}, {1, 128}, {199, 128}}},
|
||||
{{128, 166}, {{128, 166}, {128, 166}}}
|
||||
},
|
||||
{true, true}
|
||||
@ -670,9 +670,9 @@ const std::vector<ShapeRelatedParams> IS3D_Brgconv1x1_smoke = {
|
||||
{
|
||||
{
|
||||
// ip->brg->ip->brg
|
||||
// {1, 32, 120}, {3, 29, 120} are covered in 'IS3D_smoke' which is ip
|
||||
// {1, 1, 120}, {3, 1, 120} are covered in 'IS3D_smoke' which is ip
|
||||
// {2, 49, 120}, {4, 79, 120} are covered above which is brg1x1
|
||||
{{-1, -1, -1}, {{1, 32, 120}, {2, 49, 120}, {3, 29, 120}, {4, 79, 120}}},
|
||||
{{-1, -1, -1}, {{1, 1, 120}, {2, 49, 120}, {3, 1, 120}, {4, 79, 120}}},
|
||||
{{120, 120}, {{120, 120}, {120, 120}, {120, 120}, {120, 120}}}
|
||||
},
|
||||
{false, false}
|
||||
@ -710,9 +710,30 @@ const auto testParams3D_Brgconv1x1_smoke = ::testing::Combine(fullyConnectedPara
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_Brgconv1x1, MatMulLayerCPUTest, testParams3D_Brgconv1x1_smoke, MatMulLayerCPUTest::getTestCaseName);
|
||||
|
||||
const std::vector<ShapeRelatedParams> IS2D_Brgemm_Amx_smoke = {
|
||||
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, false}},
|
||||
{static_shapes_to_test_representation({{59, 16}, {16, 120}}), {true, true}},
|
||||
|
||||
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, false}},
|
||||
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, true}},
|
||||
|
||||
const auto fullyConnectedParams2D_Brgemm_Amx_smoke = ::testing::Combine(::testing::ValuesIn(IS2D_Brgemm_smoke),
|
||||
{
|
||||
{
|
||||
{{-1, -1}, {{12, 16}, {25, 16}, {12, 16}, {25, 16}}},
|
||||
{{16, 35}, {{16, 35}, {16, 35}, {16, 35}, {16, 35}}}
|
||||
},
|
||||
{false, false}
|
||||
},
|
||||
{
|
||||
{
|
||||
{{{0, 50}, {0, 50}}, {{17, 48}, {15, 48}}},
|
||||
{{48, 15}, {{48, 15}, {48, 15}}}
|
||||
},
|
||||
{true, true}
|
||||
},
|
||||
};
|
||||
|
||||
const auto fullyConnectedParams2D_Brgemm_Amx_smoke = ::testing::Combine(::testing::ValuesIn(IS2D_Brgemm_Amx_smoke),
|
||||
::testing::Values(ElementType::f32),
|
||||
::testing::Values(ElementType::undefined),
|
||||
::testing::Values(ElementType::undefined),
|
||||
|
Loading…
Reference in New Issue
Block a user