[ARM CPU] Avg Pooling, ROI Pooling fix for fp16 precision (#20658)
committed by GitHub
parent 4e41678502
commit 9decbb538b
@@ -322,7 +322,7 @@ void Pooling::getSupportedDescriptors() {
         // WA: LPT transformation has WA which allows average pooling has I8/U8 output precision instead of FP32,
         // so we explicitly set output precision as FP32
-        if (outputPrecision != Precision::I8 && inputPrecision != Precision::BF16) {
+        if (!one_of(outputPrecision, Precision::I8, Precision::BF16, Precision::FP16)) {
             if (getAlgorithm() == Algorithm::PoolingMax) {
                 // oneDNN supports only equal precisions for input and output
                 outputPrecision = inputPrecision;
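The new condition uses the plugin's one_of helper, a variadic membership test. A minimal sketch of such a helper (an illustrative re-implementation, not the plugin's own definition):

    // Minimal sketch of a one_of-style membership test: true when `val`
    // equals any of the trailing arguments (C++17 fold expression).
    template <typename T, typename... Args>
    constexpr bool one_of(T val, Args... items) {
        return ((val == items) || ...);
    }

    // Example: mirrors the new pooling check above.
    // one_of(outputPrecision, Precision::I8, Precision::BF16, Precision::FP16)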
@@ -330,7 +330,7 @@ void Pooling::getSupportedDescriptors() {
                 outputPrecision = Precision::FP32;
             }
         }

-        if (inputPrecision == Precision::BF16) {
+        if (one_of(inputPrecision, Precision::BF16, Precision::FP16)) {
             outputPrecision = inputPrecision;
         }
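Taken together, the two pooling hunks extend the BF16-only special cases to FP16, so half-precision inputs keep half-precision outputs. A condensed, self-contained sketch of the resulting control flow (the enums and the function name are illustrative stand-ins, not the plugin's API):

    enum class Precision { FP32, FP16, BF16, I8, U8 };
    enum class Algorithm { PoolingMax, PoolingAvg };

    // Illustrative condensation of the updated logic in Pooling::getSupportedDescriptors().
    Precision resolveOutputPrecision(Precision in, Precision out, Algorithm alg) {
        if (out != Precision::I8 && out != Precision::BF16 && out != Precision::FP16) {
            // oneDNN max pooling needs equal input/output precisions; average pooling
            // falls back to FP32 because of the LPT workaround mentioned above.
            out = (alg == Algorithm::PoolingMax) ? in : Precision::FP32;
        }
        if (in == Precision::BF16 || in == Precision::FP16)
            out = in;  // half-precision inputs now keep half-precision outputs
        return out;
    }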
@@ -351,7 +351,7 @@ void Pooling::getSupportedDescriptors() {
     if (inputPrecision == Precision::I8 || inputPrecision == Precision::U8) {
         // We have to extend i8i8_pooling_fwd_t from oneDNN to support BF16 output data type
-        if (outputDataType == memory::data_type::bf16)
+        if (one_of(outputDataType, memory::data_type::bf16, memory::data_type::f16))
             outputDataType = memory::data_type::f32;
         // i8 layers supports only ndhwc and nhwc layouts
         const auto in_candidate = std::make_shared<DnnlBlockedMemoryDesc>(parentShape, inputDataType, inputRank == 3 ?
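For the quantized path, the same pattern is applied to the oneDNN output data type: half-precision output requests drop to f32 because i8i8_pooling_fwd_t does not produce bf16/f16. A small sketch of that fallback, assuming the standard oneDNN C++ API (the helper name is hypothetical):

    #include <oneapi/dnnl/dnnl.hpp>

    using dnnl::memory;

    // Hypothetical helper mirroring the fallback above: quantized pooling cannot emit
    // bf16/f16, so half-precision output requests are downgraded to f32.
    inline memory::data_type quantizedPoolingOutputType(memory::data_type requested) {
        if (requested == memory::data_type::bf16 || requested == memory::data_type::f16)
            return memory::data_type::f32;
        return requested;
    }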
@@ -434,13 +434,6 @@ void ROIPooling::initSupportedPrimitiveDescriptors() {
     if (!supportedPrimitiveDescriptors.empty())
         return;

-    refParams.src_prc = getOriginalInputPrecisionAtPort(0);
-
-    if (!mayiuse(avx512_core)) {
-        if (refParams.src_prc == Precision::BF16)
-            refParams.src_prc = Precision::FP32;
-    }
-
     auto format = mayiuse(avx512_core) ? LayoutType::nCsp16c : LayoutType::nCsp8c;
     impl_desc_type impl_type;
     if (mayiuse(cpu::x64::avx512_core)) {
@@ -453,6 +446,17 @@ void ROIPooling::initSupportedPrimitiveDescriptors() {
         impl_type = impl_desc_type::ref;
     }

+    refParams.src_prc = getOriginalInputPrecisionAtPort(0);
+
+    if (!mayiuse(avx512_core)) {
+        if (refParams.src_prc == Precision::BF16)
+            refParams.src_prc = Precision::FP32;
+    }
+
+    if (impl_type != impl_desc_type::ref && refParams.src_prc == Precision::FP16) {
+        refParams.src_prc = Precision::FP32;
+    }
+
     addSupportedPrimDesc({{format, refParams.src_prc},
                           {LayoutType::ncsp, refParams.src_prc}},
                          {{format, refParams.src_prc}},
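The ROIPooling change moves the source-precision adjustment to after impl_type is selected, which lets the new FP16 rule key on it: FP16 is kept only when the reference executor is used, while BF16 still requires avx512_core. A condensed, self-contained sketch of the reordered logic (the enum and function name are illustrative stand-ins):

    enum class Precision { FP32, FP16, BF16 };

    // Illustrative condensation of the updated initSupportedPrimitiveDescriptors() fixups.
    Precision resolveRoiPoolingSrcPrecision(Precision src, bool hasAvx512Core, bool isRefImpl) {
        if (!hasAvx512Core && src == Precision::BF16)
            src = Precision::FP32;     // BF16 needs avx512_core
        if (!isRefImpl && src == Precision::FP16)
            src = Precision::FP32;     // FP16 is kept only on the reference path
        return src;
    }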
@@ -826,7 +830,8 @@ std::shared_ptr<ROIPooling::ROIPoolingExecutor> ROIPooling::ROIPoolingExecutor::
     OV_SWITCH(intel_cpu, ROIPoolingExecutorCreation, ctx, jpp.src_prc,
               OV_CASE(Precision::FP32, float),
-              OV_CASE(Precision::BF16, bfloat16_t))
+              OV_CASE(Precision::BF16, bfloat16_t),
+              OV_CASE(Precision::FP16, float16_t))

     return ctx.executor;
 }
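The executor factory gains an FP16 case. Conceptually, each OV_CASE maps a runtime precision tag to a compile-time element type for the executor template. A self-contained sketch of that dispatch pattern (all names here are illustrative stand-ins, not the plugin's API):

    #include <cstdint>
    #include <memory>

    enum class Precision { FP32, BF16, FP16 };

    struct bfloat16_t { uint16_t bits; };  // stand-in for the plugin's bf16 storage type
    struct float16_t  { uint16_t bits; };  // stand-in for the plugin's f16 storage type

    struct Executor { virtual ~Executor() = default; };

    // The element type is fixed at compile time per instantiation.
    template <typename T>
    struct TypedExecutor : Executor {};

    // Runtime precision selects which instantiation is constructed.
    std::shared_ptr<Executor> createExecutor(Precision prc) {
        switch (prc) {
        case Precision::FP32: return std::make_shared<TypedExecutor<float>>();
        case Precision::BF16: return std::make_shared<TypedExecutor<bfloat16_t>>();
        case Precision::FP16: return std::make_shared<TypedExecutor<float16_t>>();  // new case
        }
        return nullptr;
    }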
@@ -229,15 +229,9 @@ std::vector<std::string> disabledTestPatterns() {

 #if defined(OV_CPU_ARM_ENABLE_FP16)
     // Issue: 123019
-    retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_CeilRounding.*modelType=f16.*)");
-    retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_FloorRounding_5Dinput/PoolingLayerTest.*modelType=f16.*)");
-    retVector.emplace_back(R"(smoke_AvgPool_SameUpperPad_FloorRounding_5Dinput/PoolingLayerTest.*modelType=f16.*)");
-    retVector.emplace_back(R"(smoke_AvgPool_SameLowerPad_CeilRounding_5Dinput/PoolingLayerTest.*modelType=f16.*)");
     retVector.emplace_back(R"(smoke_CompareWithRefs_Mvn.*INFERENCE_PRECISION_HINT=f16.*)");
     retVector.emplace_back(R"(smoke_staticShapes4D.*INFERENCE_PRECISION_HINT=f16.*)");
     retVector.emplace_back(R"(smoke_dynamicShapes4D.*INFERENCE_PRECISION_HINT=f16.*)");
-    // Issue: 123064
-    retVector.emplace_back(R"(smoke_TestsROIPooling_.*/ROIPoolingLayerTest.*modelType=f16.*)");
 #endif

 #endif
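The removed entries are regex patterns that had suppressed the affected f16 AvgPool and ROIPooling tests; with the fix they run again on ARM fp16 builds. For context, such a skip list is consumed by matching each pattern against the full test name, roughly as follows (a minimal sketch, not the actual test-utils code):

    #include <regex>
    #include <string>
    #include <vector>

    // Minimal sketch: a test is skipped when its full name matches any disabled pattern.
    bool isTestDisabled(const std::string& fullTestName, const std::vector<std::string>& patterns) {
        for (const auto& pattern : patterns)
            if (std::regex_search(fullTestName, std::regex(pattern)))
                return true;
        return false;
    }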