[CPU] Single place to fill tails of load emitter (#14846)

This commit is contained in:
Chenhu Wang
2023-03-07 01:31:39 +08:00
committed by GitHub
parent 5269cb37d8
commit 6e7bef529f
2 changed files with 28 additions and 9 deletions

View File

@@ -150,6 +150,11 @@ void jit_load_emitter::emit_isa(const Xbyak::Reg64 &reg_src, const int out_vec_i
break;
}
}
if (is_fill_) {
int dword_num_loaded = (src_prc_ != dst_prc_) ? load_num_ : (load_size_ / sizeof(float));
fill_with_default(Vmm(out_vec_idx), fill_value_, dword_num_loaded);
}
}
/**
@@ -313,9 +318,6 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
break;
}
}
if (is_fill_)
fill_with_default(vmm, fill_value_, load_size / 4);
}
/**
@@ -407,9 +409,6 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak
break;
}
}
if (is_fill_)
fill_with_default(vmm, fill_value_, load_size);
}
/**
@@ -524,9 +523,6 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak
break;
}
}
if (is_fill_)
fill_with_default(vmm, fill_value_, load_size / 2);
}
template <typename Vmm>

View File

@@ -414,6 +414,29 @@ const auto Mvn2DTrans = ::testing::Combine(
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2DTrans, MvnLayerCPUTest, Mvn2DTrans, MvnLayerCPUTest::getTestCaseName);
// no transformed with small spatial dim and i8 data and no fusion to cover model use case
const std::vector<InputShape> inputShapesSmallSpatial = {
{ {}, {{4, 1}}},
{ {}, {{2, 2}}},
{ {}, {{1, 2, 1}}},
{ {}, {{3, 1, 1, 1}}},
};
const auto MvnSmallSpatial = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapesSmallSpatial),
::testing::Values(ElementType::i8),
::testing::ValuesIn(emptyReductionAxes),
::testing::Values(false),
::testing::Values(false),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::Values(emptyFusingSpec),
::testing::Values(ElementType::i8),
::testing::Values(ElementType::f32));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_MvnSmallSpatial, MvnLayerCPUTest, MvnSmallSpatial, MvnLayerCPUTest::getTestCaseName);
// Static shape test for some specific fusing parameters in fusingParamsSetStaticShape
const std::vector<ov::Shape> inputShapesStatic_2D = {