[IE VPU] Enable s32->u8 conversion (#699)

This commit is contained in:
Maksim Doronin 2020-06-02 12:20:06 +03:00 committed by GitHub
parent 278868b7a1
commit daaeaa5881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 13 additions and 3 deletions

View File

@ -99,6 +99,7 @@ const SupportedConversionSet ConvertStage::expectedTypes = {
{DataType::FP32, DataType::FP16}, {DataType::FP32, DataType::FP16},
{DataType::S32, DataType::FP16}, {DataType::S32, DataType::FP16},
{DataType::FP16, DataType::S32}, {DataType::FP16, DataType::S32},
{DataType::S32, DataType::U8},
}; };
} // namespace } // namespace

View File

@ -134,7 +134,8 @@ std::vector<SizeVector> inputsDims4D = {
std::vector<PrecisionPair> precisionsIO = { std::vector<PrecisionPair> precisionsIO = {
{Precision::U8, Precision::FP16}, {Precision::U8, Precision::FP16},
{Precision::FP32, Precision::FP16}, {Precision::FP32, Precision::FP16},
{Precision::FP16, Precision::FP32} {Precision::FP16, Precision::FP32},
{Precision::I32, Precision::U8}
}; };
std::vector<Precision> withFP16Precisions = { std::vector<Precision> withFP16Precisions = {

View File

@ -2558,6 +2558,8 @@ void ref_convert(const InferenceEngine::Blob::Ptr &src,
} else if (srcPrecision == Precision::I32 && dstPrecision == Precision::FP16) { } else if (srcPrecision == Precision::I32 && dstPrecision == Precision::FP16) {
dst->buffer().as<ie_fp16 *>()[i] = PrecisionUtils::f32tof16( dst->buffer().as<ie_fp16 *>()[i] = PrecisionUtils::f32tof16(
static_cast<float >(src->cbuffer().as<int32_t *>()[i])); static_cast<float >(src->cbuffer().as<int32_t *>()[i]));
} else if (srcPrecision == Precision::I32 && dstPrecision == Precision::U8) {
dst->buffer().as<uint8_t *>()[i] = static_cast<uint8_t>(src->cbuffer().as<int32_t *>()[i]);
} else { } else {
THROW_IE_EXCEPTION << "Unsupported input or output precision"; THROW_IE_EXCEPTION << "Unsupported input or output precision";
} }

View File

@ -68,6 +68,8 @@ BufferWrapper::BufferWrapper(const Blob::Ptr& blob, Precision _precision) : prec
fp32_ptr = blob->buffer().as<float*>(); fp32_ptr = blob->buffer().as<float*>();
} else if (precision == Precision::I32) { } else if (precision == Precision::I32) {
i32_ptr = blob->buffer().as<int32_t*>(); i32_ptr = blob->buffer().as<int32_t*>();
} else if (precision == Precision::U8) {
u8_ptr = blob->buffer().as<uint8_t*>();
} else { } else {
THROW_IE_EXCEPTION << "Unsupported precision for compare: " << precision; THROW_IE_EXCEPTION << "Unsupported precision for compare: " << precision;
} }
@ -78,6 +80,8 @@ float BufferWrapper::operator[](size_t index) {
return PrecisionUtils::f16tof32(fp16_ptr[index]); return PrecisionUtils::f16tof32(fp16_ptr[index]);
} else if (precision == Precision::I32) { } else if (precision == Precision::I32) {
return i32_ptr[index]; return i32_ptr[index];
} else if (precision == Precision::U8) {
return u8_ptr[index];
} }
return fp32_ptr[index]; return fp32_ptr[index];
} }
@ -87,8 +91,9 @@ void BufferWrapper::insert(size_t index, float value) {
fp16_ptr[index] = PrecisionUtils::f32tof16(value); fp16_ptr[index] = PrecisionUtils::f32tof16(value);
} else if (precision == Precision::I32) { } else if (precision == Precision::I32) {
i32_ptr[index] = value; i32_ptr[index] = value;
} } else if (precision == Precision::U8) {
else { u8_ptr[index] = value;
} else {
fp32_ptr[index] = value; fp32_ptr[index] = value;
} }
} }

View File

@ -115,6 +115,7 @@ class BufferWrapper {
InferenceEngine::ie_fp16 *fp16_ptr; InferenceEngine::ie_fp16 *fp16_ptr;
float *fp32_ptr; float *fp32_ptr;
int32_t *i32_ptr; int32_t *i32_ptr;
uint8_t *u8_ptr;
public: public:
explicit BufferWrapper(const InferenceEngine::Blob::Ptr &blob); explicit BufferWrapper(const InferenceEngine::Blob::Ptr &blob);