[IE VPU] Enable s32->u8 conversion (#699)

This commit is contained in:
Maksim Doronin 2020-06-02 12:20:06 +03:00 committed by GitHub
parent 278868b7a1
commit daaeaa5881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 13 additions and 3 deletions

View File

@ -99,6 +99,7 @@ const SupportedConversionSet ConvertStage::expectedTypes = {
{DataType::FP32, DataType::FP16},
{DataType::S32, DataType::FP16},
{DataType::FP16, DataType::S32},
{DataType::S32, DataType::U8},
};
} // namespace

View File

@ -134,7 +134,8 @@ std::vector<SizeVector> inputsDims4D = {
std::vector<PrecisionPair> precisionsIO = {
{Precision::U8, Precision::FP16},
{Precision::FP32, Precision::FP16},
{Precision::FP16, Precision::FP32}
{Precision::FP16, Precision::FP32},
{Precision::I32, Precision::U8}
};
std::vector<Precision> withFP16Precisions = {

View File

@ -2558,6 +2558,8 @@ void ref_convert(const InferenceEngine::Blob::Ptr &src,
} else if (srcPrecision == Precision::I32 && dstPrecision == Precision::FP16) {
dst->buffer().as<ie_fp16 *>()[i] = PrecisionUtils::f32tof16(
static_cast<float >(src->cbuffer().as<int32_t *>()[i]));
} else if (srcPrecision == Precision::I32 && dstPrecision == Precision::U8) {
dst->buffer().as<uint8_t *>()[i] = static_cast<uint8_t>(src->cbuffer().as<int32_t *>()[i]);
} else {
THROW_IE_EXCEPTION << "Unsupported input or output precision";
}

View File

@ -68,6 +68,8 @@ BufferWrapper::BufferWrapper(const Blob::Ptr& blob, Precision _precision) : prec
fp32_ptr = blob->buffer().as<float*>();
} else if (precision == Precision::I32) {
i32_ptr = blob->buffer().as<int32_t*>();
} else if (precision == Precision::U8) {
u8_ptr = blob->buffer().as<uint8_t*>();
} else {
THROW_IE_EXCEPTION << "Unsupported precision for compare: " << precision;
}
@ -78,6 +80,8 @@ float BufferWrapper::operator[](size_t index) {
return PrecisionUtils::f16tof32(fp16_ptr[index]);
} else if (precision == Precision::I32) {
return i32_ptr[index];
} else if (precision == Precision::U8) {
return u8_ptr[index];
}
return fp32_ptr[index];
}
@ -87,8 +91,9 @@ void BufferWrapper::insert(size_t index, float value) {
fp16_ptr[index] = PrecisionUtils::f32tof16(value);
} else if (precision == Precision::I32) {
i32_ptr[index] = value;
}
else {
} else if (precision == Precision::U8) {
u8_ptr[index] = value;
} else {
fp32_ptr[index] = value;
}
}

View File

@ -115,6 +115,7 @@ class BufferWrapper {
InferenceEngine::ie_fp16 *fp16_ptr;
float *fp32_ptr;
int32_t *i32_ptr;
uint8_t *u8_ptr;
public:
explicit BufferWrapper(const InferenceEngine::Blob::Ptr &blob);