cleanup constant op class (#5186)
Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
parent
11990e50aa
commit
75f9242cb4
@ -71,7 +71,7 @@ namespace ngraph
|
||||
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
|
||||
}
|
||||
|
||||
/// \brief Create unitialized constant
|
||||
/// \brief Create uninitialized constant
|
||||
Constant(const element::Type& type, const Shape& shape);
|
||||
/// \brief Constructs a uniform tensor constant.
|
||||
///
|
||||
@ -84,7 +84,7 @@ namespace ngraph
|
||||
Constant(const element::Type& type, const Shape& shape, T value)
|
||||
: Constant(type, shape)
|
||||
{
|
||||
auto size = shape_size(m_shape);
|
||||
using Type_t = element::Type_t;
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic error "-Wswitch"
|
||||
@ -92,115 +92,24 @@ namespace ngraph
|
||||
#endif
|
||||
switch (type)
|
||||
{
|
||||
case element::Type_t::boolean:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::boolean>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::boolean>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::bf16:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::bf16>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::bf16>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::f16:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::f16>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::f16>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::f32>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::f32>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::f64:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::f64>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::f64>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::i8>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::i8>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::i16:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::i16>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::i16>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::i32>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::i32>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::i64>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::i64>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::u8>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::u8>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::u16:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::u16>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::u16>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::u32>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::u32>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
std::fill_n(
|
||||
get_data_ptr_nc<element::Type_t::u64>(),
|
||||
size,
|
||||
static_cast<
|
||||
typename element_type_traits<element::Type_t::u64>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
|
||||
case Type_t::boolean: fill_data<Type_t::boolean>(value); break;
|
||||
case Type_t::bf16: fill_data<Type_t::bf16>(value); break;
|
||||
case Type_t::f16: fill_data<Type_t::f16>(value); break;
|
||||
case Type_t::f32: fill_data<Type_t::f32>(value); break;
|
||||
case Type_t::f64: fill_data<Type_t::f64>(value); break;
|
||||
case Type_t::i8: fill_data<Type_t::i8>(value); break;
|
||||
case Type_t::i16: fill_data<Type_t::i16>(value); break;
|
||||
case Type_t::i32: fill_data<Type_t::i32>(value); break;
|
||||
case Type_t::i64: fill_data<Type_t::i64>(value); break;
|
||||
case Type_t::u8: fill_data<Type_t::u8>(value); break;
|
||||
case Type_t::u16: fill_data<Type_t::u16>(value); break;
|
||||
case Type_t::u32: fill_data<Type_t::u32>(value); break;
|
||||
case Type_t::u64: fill_data<Type_t::u64>(value); break;
|
||||
case Type_t::i4:
|
||||
case Type_t::u1:
|
||||
case Type_t::u4:
|
||||
case Type_t::undefined:
|
||||
case Type_t::dynamic: throw std::runtime_error("unsupported type");
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
#pragma GCC diagnostic pop
|
||||
@ -355,78 +264,26 @@ namespace ngraph
|
||||
{
|
||||
auto source_type = get_element_type();
|
||||
std::vector<T> rc;
|
||||
|
||||
using Type_t = element::Type_t;
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4244)
|
||||
#endif
|
||||
switch (source_type)
|
||||
{
|
||||
case element::Type_t::boolean:
|
||||
{
|
||||
cast_vector<char>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::bf16:
|
||||
{
|
||||
cast_vector<bfloat16>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::f16:
|
||||
{
|
||||
cast_vector<float16>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::f32:
|
||||
{
|
||||
cast_vector<float>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::f64:
|
||||
{
|
||||
cast_vector<double>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i8:
|
||||
{
|
||||
cast_vector<int8_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i16:
|
||||
{
|
||||
cast_vector<int16_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i32:
|
||||
{
|
||||
cast_vector<int32_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i64:
|
||||
{
|
||||
cast_vector<int64_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u8:
|
||||
{
|
||||
cast_vector<uint8_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u16:
|
||||
{
|
||||
cast_vector<uint16_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u32:
|
||||
{
|
||||
cast_vector<uint32_t>(rc);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u64:
|
||||
{
|
||||
cast_vector<uint64_t>(rc);
|
||||
break;
|
||||
}
|
||||
case Type_t::boolean: cast_vector<Type_t::boolean>(rc); break;
|
||||
case Type_t::bf16: cast_vector<Type_t::bf16>(rc); break;
|
||||
case Type_t::f16: cast_vector<Type_t::f16>(rc); break;
|
||||
case Type_t::f32: cast_vector<Type_t::f32>(rc); break;
|
||||
case Type_t::f64: cast_vector<Type_t::f64>(rc); break;
|
||||
case Type_t::i8: cast_vector<Type_t::i8>(rc); break;
|
||||
case Type_t::i16: cast_vector<Type_t::i16>(rc); break;
|
||||
case Type_t::i32: cast_vector<Type_t::i32>(rc); break;
|
||||
case Type_t::i64: cast_vector<Type_t::i64>(rc); break;
|
||||
case Type_t::u8: cast_vector<Type_t::u8>(rc); break;
|
||||
case Type_t::u16: cast_vector<Type_t::u16>(rc); break;
|
||||
case Type_t::u32: cast_vector<Type_t::u32>(rc); break;
|
||||
case Type_t::u64: cast_vector<Type_t::u64>(rc); break;
|
||||
default: throw std::runtime_error("unsupported type");
|
||||
}
|
||||
#if defined(_MSC_VER)
|
||||
@ -471,9 +328,13 @@ namespace ngraph
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename IN_T, typename OUT_T>
|
||||
template <element::Type_t Type, typename OUT_T>
|
||||
void cast_vector(std::vector<OUT_T>& output_vector) const
|
||||
{
|
||||
// this function is workaround for waring during windows building
|
||||
// build complains for vector creation based on iterators
|
||||
// which point on different type than destination vector::value_type
|
||||
using IN_T = fundamental_type_for<Type>;
|
||||
auto source_vector = get_vector<IN_T>();
|
||||
output_vector.reserve(source_vector.size());
|
||||
|
||||
@ -483,6 +344,15 @@ namespace ngraph
|
||||
[](IN_T c) { return static_cast<OUT_T>(c); });
|
||||
}
|
||||
|
||||
template <element::Type_t Type,
|
||||
typename T,
|
||||
typename DataStorageType = fundamental_type_for<Type>>
|
||||
void fill_data(const T& value)
|
||||
{
|
||||
const auto size = shape_size(m_shape);
|
||||
std::fill_n(get_data_ptr_nc<Type>(), size, static_cast<DataStorageType>(value));
|
||||
}
|
||||
|
||||
void allocate_buffer();
|
||||
|
||||
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
|
||||
@ -587,12 +457,13 @@ namespace ngraph
|
||||
#endif
|
||||
}
|
||||
|
||||
bool are_all_data_elements_bitwise_identical() const;
|
||||
static constexpr size_t host_alignment() { return 64; }
|
||||
|
||||
element::Type m_element_type;
|
||||
Shape m_shape{};
|
||||
std::shared_ptr<runtime::AlignedBuffer> m_data;
|
||||
bool m_all_elements_bitwise_identical;
|
||||
bool are_all_data_elements_bitwise_identical() const;
|
||||
bool m_alloc_buffer_on_visit_attributes = true;
|
||||
};
|
||||
}
|
||||
|
@ -13,6 +13,9 @@ namespace ngraph
|
||||
{
|
||||
};
|
||||
|
||||
template <element::Type_t Type>
|
||||
using fundamental_type_for = typename element_type_traits<Type>::value_type;
|
||||
|
||||
template <>
|
||||
struct element_type_traits<element::Type_t::boolean>
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user