refactor convertOutputPrecision (#5338)
Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
parent
cf6acfde78
commit
6a850b1e7b
@ -288,16 +288,6 @@ std::shared_ptr<ngraph::Node> getNodeSharedPtr(const ngraph::NodeTypeInfo &type_
|
||||
NGRAPH_UNREACHABLE("supported opsets does not contain op with name: ", type_info.name, " version: ", type_info.version);
|
||||
}
|
||||
|
||||
/// Converts `elementsCount` values stored in `buffer` as `fromPrec` into `toPrec`
/// (element-wise `static_cast`) and returns them as a raw byte buffer.
///
/// @param buffer        raw bytes holding the source values, read as `fromPrec`
/// @param elementsCount number of values to convert
/// @param elementSize   number of bytes reserved per converted value in the result
/// @return byte buffer of `elementsCount * elementSize` bytes; the converted values are
///         written contiguously as `toPrec` starting at offset 0
/// @throws std::runtime_error if `buffer` holds fewer than `elementsCount` source values,
///         or if `elementSize` is too small to store a `toPrec` value
template <typename fromPrec, typename toPrec>
std::vector<std::uint8_t> convertPrecision(const std::vector<std::uint8_t> &buffer, const size_t elementsCount, const size_t elementSize) {
    // Division form keeps the comparison free of multiplication overflow.
    if (elementsCount > buffer.size() / sizeof(fromPrec)) {
        throw std::runtime_error("convertPrecision: source buffer holds fewer elements than elementsCount");
    }
    // The loop below writes sizeof(toPrec) bytes per element, so a smaller slot would overflow the result.
    if (elementsCount != 0 && elementSize < sizeof(toPrec)) {
        throw std::runtime_error("convertPrecision: elementSize is smaller than the destination element type");
    }

    std::vector<std::uint8_t> convertedData(elementsCount * elementSize);
    const fromPrec *src = reinterpret_cast<const fromPrec *>(buffer.data());
    toPrec *dst = reinterpret_cast<toPrec *>(convertedData.data());
    for (size_t i = 0; i < elementsCount; ++i) {
        dst[i] = static_cast<toPrec>(src[i]);
    }
    return convertedData;
}
|
||||
|
||||
bool is_tensor_iterator_exist(const std::shared_ptr<ngraph::Function> & func) {
|
||||
const auto& ops = func->get_ops();
|
||||
for (const auto& node : ops) {
|
||||
@ -309,429 +299,124 @@ bool is_tensor_iterator_exist(const std::shared_ptr<ngraph::Function> & func) {
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <element::Type_t FromType, element::Type_t ToType>
|
||||
std::vector<std::uint8_t> convertPrecision(const std::vector<std::uint8_t> &buffer, const size_t elementsCount) {
|
||||
using fromPrec = fundamental_type_for<FromType>;
|
||||
using toPrec = fundamental_type_for<ToType>;
|
||||
|
||||
NGRAPH_CHECK(buffer.size() >= elementsCount * sizeof(fromPrec), "avoid buffer overflow");
|
||||
|
||||
constexpr auto elementSize = sizeof(toPrec);
|
||||
std::vector<std::uint8_t> convertedData(elementsCount * elementSize);
|
||||
|
||||
const fromPrec *src = reinterpret_cast<const fromPrec *>(buffer.data());
|
||||
toPrec *dst = reinterpret_cast<toPrec *>(convertedData.data());
|
||||
for (size_t i = 0; i < elementsCount; i++) {
|
||||
dst[i] = static_cast<toPrec>(src[i]);
|
||||
}
|
||||
return convertedData;
|
||||
}
|
||||
|
||||
template <element::Type_t FromType>
|
||||
std::vector<std::uint8_t> convertPrecisionFrom(const std::vector<std::uint8_t> &output, const element::Type_t &toPrecision, const size_t elementsCount) {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::boolean: {
|
||||
return convertPrecision<FromType, element::Type_t::boolean>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<FromType, element::Type_t::bf16>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<FromType, element::Type_t::f16>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<FromType, element::Type_t::f32>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::f64: {
|
||||
return convertPrecision<FromType, element::Type_t::f64>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<FromType, element::Type_t::i8>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<FromType, element::Type_t::i16>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<FromType, element::Type_t::i32>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<FromType, element::Type_t::i64>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<FromType, element::Type_t::u8>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<FromType, element::Type_t::u16>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::u32: {
|
||||
return convertPrecision<FromType, element::Type_t::u32>(output, elementsCount);
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<FromType, element::Type_t::u64>(output, elementsCount);
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
std::string("convertOutputPrecision can't convert from: ") + element::Type(FromType).get_type_name() +
|
||||
" to: " + element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
std::vector<std::uint8_t> convertOutputPrecision(const std::vector<std::uint8_t> &output,
|
||||
const element::Type_t &fromPrecision,
|
||||
const element::Type_t &toPrecision,
|
||||
const size_t elementsCount) {
|
||||
switch (fromPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<uint8_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<uint8_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<uint8_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<uint8_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<uint8_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<uint8_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<uint8_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<uint8_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<uint8_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<uint8_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<uint16_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<uint16_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<uint16_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<uint16_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<uint16_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<uint16_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<uint16_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<uint16_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<uint16_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<uint16_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<int8_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<int8_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<int8_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<int8_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<int8_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<int8_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<int8_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<int8_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<int8_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<int8_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<int16_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<int16_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<int16_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<int16_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<int16_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<int16_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<int16_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<int16_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<int16_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<int16_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<int32_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<int32_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<int32_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<int32_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<int32_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<int32_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<int32_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<int32_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<int32_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<int32_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<int64_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<int64_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<int64_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<int64_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<int64_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<int64_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<int64_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<int64_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<int64_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<int64_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<uint64_t, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<uint64_t, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<uint64_t, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<uint64_t, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<uint64_t, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<uint64_t, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<uint64_t, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<uint64_t, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<uint64_t, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<uint64_t, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<float, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<float, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<float, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<float, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<float, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<float, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<float, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<float, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<float, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<float, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::boolean: {
|
||||
return convertPrecision<float, char>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::boolean: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<char, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<char, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<char, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<char, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<char, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<char, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<char, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<char, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<char, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<char, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<ngraph::float16, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<ngraph::float16, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<ngraph::float16, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<ngraph::float16, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<ngraph::float16, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<ngraph::float16, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<ngraph::float16, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<ngraph::float16, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<ngraph::float16, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<ngraph::float16, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::boolean: {
|
||||
return convertPrecision<ngraph::float16, char>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
switch (toPrecision) {
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecision<ngraph::bfloat16, uint8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecision<ngraph::bfloat16, uint16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecision<ngraph::bfloat16, int8_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecision<ngraph::bfloat16, int16_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecision<ngraph::bfloat16, int32_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecision<ngraph::bfloat16, int64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecision<ngraph::bfloat16, float>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecision<ngraph::bfloat16, ngraph::float16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecision<ngraph::bfloat16, ngraph::bfloat16>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecision<ngraph::bfloat16, uint64_t>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
case element::Type_t::boolean: {
|
||||
return convertPrecision<ngraph::bfloat16, char>(output, elementsCount, element::Type(toPrecision).size());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " to: " +
|
||||
element::Type(toPrecision).get_type_name());
|
||||
}
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("convertOutputPrecision can't convert from: " + element::Type(fromPrecision).get_type_name() + " precision");
|
||||
case element::Type_t::boolean: {
|
||||
return convertPrecisionFrom<element::Type_t::boolean>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::bf16: {
|
||||
return convertPrecisionFrom<element::Type_t::bf16>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::f16: {
|
||||
return convertPrecisionFrom<element::Type_t::f16>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::f32: {
|
||||
return convertPrecisionFrom<element::Type_t::f32>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::f64: {
|
||||
return convertPrecisionFrom<element::Type_t::f64>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::i8: {
|
||||
return convertPrecisionFrom<element::Type_t::i8>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::i16: {
|
||||
return convertPrecisionFrom<element::Type_t::i16>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::i32: {
|
||||
return convertPrecisionFrom<element::Type_t::i32>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
return convertPrecisionFrom<element::Type_t::i64>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::u8: {
|
||||
return convertPrecisionFrom<element::Type_t::u8>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::u16: {
|
||||
return convertPrecisionFrom<element::Type_t::u16>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::u32: {
|
||||
return convertPrecisionFrom<element::Type_t::u32>(output, toPrecision, elementsCount);
|
||||
}
|
||||
case element::Type_t::u64: {
|
||||
return convertPrecisionFrom<element::Type_t::u64>(output, toPrecision, elementsCount);
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
std::string("convertOutputPrecision can't convert from: ") + element::Type(fromPrecision).get_type_name() +
|
||||
" precision");
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user