[CPU][ARM] Enable FP16 precision for CumSum node (#19176)
This commit is contained in:
parent
600c2d8283
commit
680333b2db
@ -12,6 +12,7 @@
|
||||
#include <ie_ngraph_utils.hpp>
|
||||
#include "cum_sum.h"
|
||||
#include "utils/bfloat16.hpp"
|
||||
#include "openvino/core/type/float16.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
@ -72,7 +73,10 @@ void CumSum::initSupportedPrimitiveDescriptors() {
|
||||
return;
|
||||
|
||||
dataPrecision = getOriginalInputPrecisionAtPort(CUM_SUM_DATA);
|
||||
if (!one_of(dataPrecision, Precision::I8, Precision::U8, Precision::I16, Precision::BF16, Precision::I32, Precision::FP32, Precision::I64, Precision::U64))
|
||||
if (!one_of(dataPrecision,
|
||||
Precision::I8, Precision::U8,
|
||||
Precision::I16, Precision::I32, Precision::I64, Precision::U64,
|
||||
Precision::BF16, Precision::FP16, Precision::FP32))
|
||||
IE_THROW() << errorPrefix << " has unsupported 'data' input precision: " << dataPrecision.name();
|
||||
|
||||
if (inputShapes.size() == numOfInputs) {
|
||||
@ -101,6 +105,7 @@ void CumSum::execute(dnnl::stream strm) {
|
||||
OV_CASE(Precision::U8, uint8_t),
|
||||
OV_CASE(Precision::I16, int16_t),
|
||||
OV_CASE(Precision::BF16, bfloat16_t),
|
||||
OV_CASE(Precision::FP16, ov::float16),
|
||||
OV_CASE(Precision::I32, int32_t),
|
||||
OV_CASE(Precision::FP32, float),
|
||||
OV_CASE(Precision::I64, int64_t),
|
||||
|
Loading…
Reference in New Issue
Block a user