[CPU] Fixed Softmax and TopK nodes initilization for ARM devices (#16950)

This commit is contained in:
Gorokhov Dmitriy 2023-04-14 22:13:42 +04:00 committed by GitHub
parent 7ce40996e5
commit cc6fd80d0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 14 deletions

View File

@ -184,6 +184,7 @@ void SoftMax::prepareParams() {
primitive_desc_iterator itpd = *desc;
auto itpd_first = itpd;
while (itpd) {
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
if (impl_type == key.implType ||
@ -195,8 +196,10 @@ void SoftMax::prepareParams() {
prim_desc = itpd.get();
break;
}
if (!itpd.next_impl())
return nullptr;
if (!itpd.next_impl()) {
prim_desc = itpd_first.get();
break;
}
}
return std::make_shared<DnnlExecutor>(prim_desc);
};

View File

@ -1954,15 +1954,6 @@ bool TopK::needPrepareParams() const {
}
void TopK::preset_params() {
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) {
layout = TopKLayoutType::topk_ncsp;
} else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) {
layout = TopKLayoutType::topk_nspc;
} else {
layout = TopKLayoutType::topk_blocked;
}
auto selectedPD = getSelectedPrimitiveDescriptor();
auto data_type = DnnlExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[TOPK_DATA].getMemDesc()->getPrecision());
data_size = DnnlExtensionUtils::sizeOfDataType(data_type);
@ -2073,6 +2064,15 @@ void TopK::prepareParams() {
}
void TopK::createPrimitive() {
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) {
layout = TopKLayoutType::topk_ncsp;
} else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) {
layout = TopKLayoutType::topk_nspc;
} else {
layout = TopKLayoutType::topk_blocked;
}
if (inputShapesDefined() && isExecutable()) {
if (needPrepareParams())
prepareParams();
@ -2108,7 +2108,6 @@ void TopK::createPrimitive() {
jcp.bitonic_k_idx_cnt = 0;
if (algorithm == TopKAlgorithm::topk_bitonic_sort) {
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
size_t src_count = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
vec_process_ptr.resize(src_count * data_size);
vec_process_idx_ptr.resize(src_count * sizeof(int32_t));

View File

@ -240,8 +240,6 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(smoke_INTEL_CPU_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
retVector.emplace_back(R"(smoke_INTEL_CPU_TestsSelect_none/SelectLayerTest.CompareWithRefImpl/COND=BOOL.*)");
retVector.emplace_back(R"(smoke_INTEL_CPU_TestsSelect_numpy/SelectLayerTest.CompareWithRefImpl/COND=BOOL.*)");
retVector.emplace_back(R"(smoke_SoftMax(2|4)D_dynamic/SoftMax8LayerTest.CompareWithRefs/NetType=f32_InType=undefined_OutType=undefined.*)");
retVector.emplace_back(R"(smoke_TopK/TopKLayerTest.CompareWithRefsDynamicBath.*)");
retVector.emplace_back(R"(smoke_Snippets.*)");
retVector.emplace_back(R"(smoke_Quantized.*)");
retVector.emplace_back(R"(smoke_NegativeQuantizedMatMulMultiplyFusion.*)");