[ARM CPU] Fix tests for eltwise layer (#16917)
This commit is contained in:
committed by
GitHub
parent
5bded05ae6
commit
d00731c0ab
@@ -11,9 +11,11 @@
|
|||||||
namespace ov {
|
namespace ov {
|
||||||
namespace intel_cpu {
|
namespace intel_cpu {
|
||||||
|
|
||||||
|
using namespace InferenceEngine;
|
||||||
|
|
||||||
class AclEltwiseExecutor : public EltwiseExecutor {
|
class AclEltwiseExecutor : public EltwiseExecutor {
|
||||||
public:
|
public:
|
||||||
AclEltwiseExecutor(const ExecutorContext::CPtr context);
|
explicit AclEltwiseExecutor(const ExecutorContext::CPtr context);
|
||||||
|
|
||||||
bool init(const EltwiseAttrs& eltwiseAttrs,
|
bool init(const EltwiseAttrs& eltwiseAttrs,
|
||||||
const std::vector<MemoryDescPtr>& srcDescs,
|
const std::vector<MemoryDescPtr>& srcDescs,
|
||||||
@@ -39,62 +41,98 @@ public:
|
|||||||
bool isSupported(const EltwiseAttrs& eltwiseAttrs,
|
bool isSupported(const EltwiseAttrs& eltwiseAttrs,
|
||||||
const std::vector<MemoryDescPtr>& srcDescs,
|
const std::vector<MemoryDescPtr>& srcDescs,
|
||||||
const std::vector<MemoryDescPtr>& dstDescs) const override {
|
const std::vector<MemoryDescPtr>& dstDescs) const override {
|
||||||
|
auto checkPrecision = [&srcDescs, &dstDescs](std::vector<Precision> srcVecPrc, Precision dstPrc) -> bool {
|
||||||
|
for (int i = 0; i < srcDescs.size(); i++) {
|
||||||
|
if (srcDescs[i]->getPrecision() != srcVecPrc[i]) return false;
|
||||||
|
}
|
||||||
|
if (dstDescs[0]->getPrecision() != dstPrc) { return false; }
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
switch (eltwiseAttrs.algorithm) {
|
switch (eltwiseAttrs.algorithm) {
|
||||||
case Algorithm::EltwiseAdd:
|
case Algorithm::EltwiseSqrt:
|
||||||
case Algorithm::EltwiseMultiply:
|
|
||||||
case Algorithm::EltwiseSubtract:
|
|
||||||
case Algorithm::EltwiseDivide:
|
case Algorithm::EltwiseDivide:
|
||||||
|
case Algorithm::EltwiseRelu:
|
||||||
|
case Algorithm::EltwiseGeluErf:
|
||||||
|
case Algorithm::EltwiseElu:
|
||||||
|
case Algorithm::EltwiseTanh:
|
||||||
|
case Algorithm::EltwiseSigmoid:
|
||||||
|
// case Algorithm::EltwisePowerDynamic: // TODO: ACL version doesn't work https://github.com/ARM-software/ComputeLibrary/issues/1047
|
||||||
|
case Algorithm::EltwiseSoftRelu:
|
||||||
|
case Algorithm::EltwiseClamp:
|
||||||
|
case Algorithm::EltwiseSwish:
|
||||||
|
case Algorithm::EltwisePrelu:
|
||||||
|
case Algorithm::EltwiseHswish:
|
||||||
|
if (!(checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||||
|
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Algorithm::EltwiseAbs:
|
||||||
|
case Algorithm::EltwiseExp:
|
||||||
|
case Algorithm::EltwiseLog:
|
||||||
|
if (!(checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
|
||||||
|
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||||
|
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case Algorithm::EltwiseMaximum:
|
case Algorithm::EltwiseMaximum:
|
||||||
case Algorithm::EltwiseMinimum:
|
case Algorithm::EltwiseMinimum:
|
||||||
case Algorithm::EltwiseSquaredDifference:
|
case Algorithm::EltwiseSquaredDifference:
|
||||||
case Algorithm::EltwisePowerDynamic:
|
if (!(checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
|
||||||
|
checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
|
||||||
|
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||||
|
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Algorithm::EltwiseAdd:
|
||||||
|
case Algorithm::EltwiseSubtract:
|
||||||
|
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
|
||||||
|
checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
|
||||||
|
checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
|
||||||
|
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||||
|
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Algorithm::EltwiseMultiply:
|
||||||
|
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
|
||||||
|
checkPrecision({Precision::U8, Precision::U8}, Precision::I16) ||
|
||||||
|
checkPrecision({Precision::U8, Precision::I16}, Precision::I16) ||
|
||||||
|
checkPrecision({Precision::I16, Precision::U8}, Precision::I16) ||
|
||||||
|
checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
|
||||||
|
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
|
||||||
|
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
// ACL supports only U8 precision on output for comparison operations
|
||||||
case Algorithm::EltwiseEqual:
|
case Algorithm::EltwiseEqual:
|
||||||
case Algorithm::EltwiseNotEqual:
|
case Algorithm::EltwiseNotEqual:
|
||||||
case Algorithm::EltwiseGreater:
|
case Algorithm::EltwiseGreater:
|
||||||
case Algorithm::EltwiseGreaterEqual:
|
case Algorithm::EltwiseGreaterEqual:
|
||||||
case Algorithm::EltwiseLess:
|
case Algorithm::EltwiseLess:
|
||||||
case Algorithm::EltwiseLessEqual:
|
case Algorithm::EltwiseLessEqual:
|
||||||
case Algorithm::EltwiseRelu:
|
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
|
||||||
case Algorithm::EltwiseGeluErf:
|
checkPrecision({Precision::I16, Precision::I16}, Precision::U8) ||
|
||||||
case Algorithm::EltwiseElu:
|
checkPrecision({Precision::I32, Precision::I32}, Precision::U8) ||
|
||||||
case Algorithm::EltwiseTanh:
|
checkPrecision({Precision::FP16, Precision::FP16}, Precision::U8) ||
|
||||||
case Algorithm::EltwiseSigmoid:
|
checkPrecision({Precision::FP32, Precision::FP32}, Precision::U8))) {
|
||||||
case Algorithm::EltwiseAbs:
|
return false;
|
||||||
case Algorithm::EltwiseSqrt:
|
}
|
||||||
case Algorithm::EltwiseSoftRelu:
|
|
||||||
case Algorithm::EltwiseExp:
|
|
||||||
case Algorithm::EltwiseClamp:
|
|
||||||
case Algorithm::EltwiseSwish:
|
|
||||||
case Algorithm::EltwisePrelu:
|
|
||||||
case Algorithm::EltwiseHswish:
|
|
||||||
case Algorithm::EltwiseLog:
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ACL supports only U8 precision on output for comparison operations
|
for (const auto & srcDesc : srcDescs) {
|
||||||
if (one_of(eltwiseAttrs.algorithm, Algorithm::EltwiseEqual, Algorithm::EltwiseNotEqual, Algorithm::EltwiseGreater,
|
if (getAclDataLayoutByMemoryDesc(srcDesc) == arm_compute::DataLayout::UNKNOWN)
|
||||||
Algorithm::EltwiseGreaterEqual, Algorithm::EltwiseLess, Algorithm::EltwiseLessEqual)) {
|
|
||||||
if (dstDescs[0]->getPrecision() != InferenceEngine::Precision::U8) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (const auto &srcD : srcDescs) {
|
for (const auto & dstDesc : dstDescs) {
|
||||||
for (const auto &dstD : dstDescs) {
|
if (getAclDataLayoutByMemoryDesc(dstDesc) == arm_compute::DataLayout::UNKNOWN)
|
||||||
if ((srcD->getPrecision() != InferenceEngine::Precision::FP32 &&
|
|
||||||
srcD->getPrecision() != InferenceEngine::Precision::FP16) ||
|
|
||||||
srcD->getPrecision() != dstD->getPrecision())
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < srcDescs.size(); i++) {
|
|
||||||
if (getAclDataLayoutByMemoryDesc(srcDescs[i]) == arm_compute::DataLayout::UNKNOWN)
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < dstDescs.size(); i++) {
|
|
||||||
if (getAclDataLayoutByMemoryDesc(dstDescs[i]) == arm_compute::DataLayout::UNKNOWN)
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -225,7 +225,6 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest.*)");
|
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest.*)");
|
||||||
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckWithSecondaryPropertiesDoubleTest.*)");
|
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckWithSecondaryPropertiesDoubleTest.*)");
|
||||||
}
|
}
|
||||||
retVector.emplace_back(R"(smoke_Decomposition_(3|4)D/Mvn6LayerTest.CompareWithRefs.*)");
|
|
||||||
retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_CeilRounding/PoolingLayerTest.CompareWithRefs.*)");
|
retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_CeilRounding/PoolingLayerTest.CompareWithRefs.*)");
|
||||||
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
|
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
|
||||||
retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)");
|
retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)");
|
||||||
|
|||||||
Reference in New Issue
Block a user