[ARM CPU] Fix tests for eltwise layer (#16917)
This commit is contained in:
committed by
GitHub
parent
5bded05ae6
commit
d00731c0ab
@@ -11,9 +11,11 @@
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class AclEltwiseExecutor : public EltwiseExecutor {
|
||||
public:
|
||||
AclEltwiseExecutor(const ExecutorContext::CPtr context);
|
||||
explicit AclEltwiseExecutor(const ExecutorContext::CPtr context);
|
||||
|
||||
bool init(const EltwiseAttrs& eltwiseAttrs,
|
||||
const std::vector<MemoryDescPtr>& srcDescs,
|
||||
@@ -39,62 +41,98 @@ public:
|
||||
bool isSupported(const EltwiseAttrs& eltwiseAttrs,
|
||||
const std::vector<MemoryDescPtr>& srcDescs,
|
||||
const std::vector<MemoryDescPtr>& dstDescs) const override {
|
||||
auto checkPrecision = [&srcDescs, &dstDescs](std::vector<Precision> srcVecPrc, Precision dstPrc) -> bool {
|
||||
for (int i = 0; i < srcDescs.size(); i++) {
|
||||
if (srcDescs[i]->getPrecision() != srcVecPrc[i]) return false;
|
||||
}
|
||||
if (dstDescs[0]->getPrecision() != dstPrc) { return false; }
|
||||
return true;
|
||||
};
|
||||
|
||||
switch (eltwiseAttrs.algorithm) {
|
||||
case Algorithm::EltwiseAdd:
|
||||
case Algorithm::EltwiseMultiply:
|
||||
case Algorithm::EltwiseSubtract:
|
||||
case Algorithm::EltwiseSqrt:
|
||||
case Algorithm::EltwiseDivide:
|
||||
case Algorithm::EltwiseRelu:
|
||||
case Algorithm::EltwiseGeluErf:
|
||||
case Algorithm::EltwiseElu:
|
||||
case Algorithm::EltwiseTanh:
|
||||
case Algorithm::EltwiseSigmoid:
|
||||
// case Algorithm::EltwisePowerDynamic: // TODO: ACL version doesn't work https://github.com/ARM-software/ComputeLibrary/issues/1047
|
||||
case Algorithm::EltwiseSoftRelu:
|
||||
case Algorithm::EltwiseClamp:
|
||||
case Algorithm::EltwiseSwish:
|
||||
case Algorithm::EltwisePrelu:
|
||||
case Algorithm::EltwiseHswish:
|
||||
if (!(checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Algorithm::EltwiseAbs:
|
||||
case Algorithm::EltwiseExp:
|
||||
case Algorithm::EltwiseLog:
|
||||
if (!(checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
|
||||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Algorithm::EltwiseMaximum:
|
||||
case Algorithm::EltwiseMinimum:
|
||||
case Algorithm::EltwiseSquaredDifference:
|
||||
case Algorithm::EltwisePowerDynamic:
|
||||
if (!(checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
|
||||
checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
|
||||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Algorithm::EltwiseAdd:
|
||||
case Algorithm::EltwiseSubtract:
|
||||
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
|
||||
checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
|
||||
checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
|
||||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Algorithm::EltwiseMultiply:
|
||||
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
|
||||
checkPrecision({Precision::U8, Precision::U8}, Precision::I16) ||
|
||||
checkPrecision({Precision::U8, Precision::I16}, Precision::I16) ||
|
||||
checkPrecision({Precision::I16, Precision::U8}, Precision::I16) ||
|
||||
checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
|
||||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
// ACL supports only U8 precision on output for comparison operations
|
||||
case Algorithm::EltwiseEqual:
|
||||
case Algorithm::EltwiseNotEqual:
|
||||
case Algorithm::EltwiseGreater:
|
||||
case Algorithm::EltwiseGreaterEqual:
|
||||
case Algorithm::EltwiseLess:
|
||||
case Algorithm::EltwiseLessEqual:
|
||||
case Algorithm::EltwiseRelu:
|
||||
case Algorithm::EltwiseGeluErf:
|
||||
case Algorithm::EltwiseElu:
|
||||
case Algorithm::EltwiseTanh:
|
||||
case Algorithm::EltwiseSigmoid:
|
||||
case Algorithm::EltwiseAbs:
|
||||
case Algorithm::EltwiseSqrt:
|
||||
case Algorithm::EltwiseSoftRelu:
|
||||
case Algorithm::EltwiseExp:
|
||||
case Algorithm::EltwiseClamp:
|
||||
case Algorithm::EltwiseSwish:
|
||||
case Algorithm::EltwisePrelu:
|
||||
case Algorithm::EltwiseHswish:
|
||||
case Algorithm::EltwiseLog:
|
||||
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
|
||||
checkPrecision({Precision::I16, Precision::I16}, Precision::U8) ||
|
||||
checkPrecision({Precision::I32, Precision::I32}, Precision::U8) ||
|
||||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::U8) ||
|
||||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::U8))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
// ACL supports only U8 precision on output for comparison operations
|
||||
if (one_of(eltwiseAttrs.algorithm, Algorithm::EltwiseEqual, Algorithm::EltwiseNotEqual, Algorithm::EltwiseGreater,
|
||||
Algorithm::EltwiseGreaterEqual, Algorithm::EltwiseLess, Algorithm::EltwiseLessEqual)) {
|
||||
if (dstDescs[0]->getPrecision() != InferenceEngine::Precision::U8) {
|
||||
for (const auto & srcDesc : srcDescs) {
|
||||
if (getAclDataLayoutByMemoryDesc(srcDesc) == arm_compute::DataLayout::UNKNOWN)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (const auto &srcD : srcDescs) {
|
||||
for (const auto &dstD : dstDescs) {
|
||||
if ((srcD->getPrecision() != InferenceEngine::Precision::FP32 &&
|
||||
srcD->getPrecision() != InferenceEngine::Precision::FP16) ||
|
||||
srcD->getPrecision() != dstD->getPrecision())
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < srcDescs.size(); i++) {
|
||||
if (getAclDataLayoutByMemoryDesc(srcDescs[i]) == arm_compute::DataLayout::UNKNOWN)
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < dstDescs.size(); i++) {
|
||||
if (getAclDataLayoutByMemoryDesc(dstDescs[i]) == arm_compute::DataLayout::UNKNOWN)
|
||||
for (const auto & dstDesc : dstDescs) {
|
||||
if (getAclDataLayoutByMemoryDesc(dstDesc) == arm_compute::DataLayout::UNKNOWN)
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -225,7 +225,6 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest.*)");
|
||||
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckWithSecondaryPropertiesDoubleTest.*)");
|
||||
}
|
||||
retVector.emplace_back(R"(smoke_Decomposition_(3|4)D/Mvn6LayerTest.CompareWithRefs.*)");
|
||||
retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_CeilRounding/PoolingLayerTest.CompareWithRefs.*)");
|
||||
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
|
||||
retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)");
|
||||
|
||||
Reference in New Issue
Block a user