cleanup constant op class (#5186)

Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
Patryk Elszkowski 2021-04-13 14:32:29 +02:00 committed by GitHub
parent 11990e50aa
commit 75f9242cb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 179 deletions

View File

@ -71,7 +71,7 @@ namespace ngraph
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
/// \brief Create unitialized constant
/// \brief Create uninitialized constant
Constant(const element::Type& type, const Shape& shape);
/// \brief Constructs a uniform tensor constant.
///
@ -84,7 +84,7 @@ namespace ngraph
Constant(const element::Type& type, const Shape& shape, T value)
: Constant(type, shape)
{
auto size = shape_size(m_shape);
using Type_t = element::Type_t;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
@ -92,115 +92,24 @@ namespace ngraph
#endif
switch (type)
{
case element::Type_t::boolean:
std::fill_n(
get_data_ptr_nc<element::Type_t::boolean>(),
size,
static_cast<
typename element_type_traits<element::Type_t::boolean>::value_type>(
value));
break;
case element::Type_t::bf16:
std::fill_n(
get_data_ptr_nc<element::Type_t::bf16>(),
size,
static_cast<
typename element_type_traits<element::Type_t::bf16>::value_type>(
value));
break;
case element::Type_t::f16:
std::fill_n(
get_data_ptr_nc<element::Type_t::f16>(),
size,
static_cast<
typename element_type_traits<element::Type_t::f16>::value_type>(
value));
break;
case element::Type_t::f32:
std::fill_n(
get_data_ptr_nc<element::Type_t::f32>(),
size,
static_cast<
typename element_type_traits<element::Type_t::f32>::value_type>(
value));
break;
case element::Type_t::f64:
std::fill_n(
get_data_ptr_nc<element::Type_t::f64>(),
size,
static_cast<
typename element_type_traits<element::Type_t::f64>::value_type>(
value));
break;
case element::Type_t::i8:
std::fill_n(
get_data_ptr_nc<element::Type_t::i8>(),
size,
static_cast<
typename element_type_traits<element::Type_t::i8>::value_type>(
value));
break;
case element::Type_t::i16:
std::fill_n(
get_data_ptr_nc<element::Type_t::i16>(),
size,
static_cast<
typename element_type_traits<element::Type_t::i16>::value_type>(
value));
break;
case element::Type_t::i32:
std::fill_n(
get_data_ptr_nc<element::Type_t::i32>(),
size,
static_cast<
typename element_type_traits<element::Type_t::i32>::value_type>(
value));
break;
case element::Type_t::i64:
std::fill_n(
get_data_ptr_nc<element::Type_t::i64>(),
size,
static_cast<
typename element_type_traits<element::Type_t::i64>::value_type>(
value));
break;
case element::Type_t::u8:
std::fill_n(
get_data_ptr_nc<element::Type_t::u8>(),
size,
static_cast<
typename element_type_traits<element::Type_t::u8>::value_type>(
value));
break;
case element::Type_t::u16:
std::fill_n(
get_data_ptr_nc<element::Type_t::u16>(),
size,
static_cast<
typename element_type_traits<element::Type_t::u16>::value_type>(
value));
break;
case element::Type_t::u32:
std::fill_n(
get_data_ptr_nc<element::Type_t::u32>(),
size,
static_cast<
typename element_type_traits<element::Type_t::u32>::value_type>(
value));
break;
case element::Type_t::u64:
std::fill_n(
get_data_ptr_nc<element::Type_t::u64>(),
size,
static_cast<
typename element_type_traits<element::Type_t::u64>::value_type>(
value));
break;
case element::Type_t::i4:
case element::Type_t::u1:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
case Type_t::boolean: fill_data<Type_t::boolean>(value); break;
case Type_t::bf16: fill_data<Type_t::bf16>(value); break;
case Type_t::f16: fill_data<Type_t::f16>(value); break;
case Type_t::f32: fill_data<Type_t::f32>(value); break;
case Type_t::f64: fill_data<Type_t::f64>(value); break;
case Type_t::i8: fill_data<Type_t::i8>(value); break;
case Type_t::i16: fill_data<Type_t::i16>(value); break;
case Type_t::i32: fill_data<Type_t::i32>(value); break;
case Type_t::i64: fill_data<Type_t::i64>(value); break;
case Type_t::u8: fill_data<Type_t::u8>(value); break;
case Type_t::u16: fill_data<Type_t::u16>(value); break;
case Type_t::u32: fill_data<Type_t::u32>(value); break;
case Type_t::u64: fill_data<Type_t::u64>(value); break;
case Type_t::i4:
case Type_t::u1:
case Type_t::u4:
case Type_t::undefined:
case Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
@ -355,78 +264,26 @@ namespace ngraph
{
auto source_type = get_element_type();
std::vector<T> rc;
using Type_t = element::Type_t;
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4244)
#endif
switch (source_type)
{
case element::Type_t::boolean:
{
cast_vector<char>(rc);
break;
}
case element::Type_t::bf16:
{
cast_vector<bfloat16>(rc);
break;
}
case element::Type_t::f16:
{
cast_vector<float16>(rc);
break;
}
case element::Type_t::f32:
{
cast_vector<float>(rc);
break;
}
case element::Type_t::f64:
{
cast_vector<double>(rc);
break;
}
case element::Type_t::i8:
{
cast_vector<int8_t>(rc);
break;
}
case element::Type_t::i16:
{
cast_vector<int16_t>(rc);
break;
}
case element::Type_t::i32:
{
cast_vector<int32_t>(rc);
break;
}
case element::Type_t::i64:
{
cast_vector<int64_t>(rc);
break;
}
case element::Type_t::u8:
{
cast_vector<uint8_t>(rc);
break;
}
case element::Type_t::u16:
{
cast_vector<uint16_t>(rc);
break;
}
case element::Type_t::u32:
{
cast_vector<uint32_t>(rc);
break;
}
case element::Type_t::u64:
{
cast_vector<uint64_t>(rc);
break;
}
case Type_t::boolean: cast_vector<Type_t::boolean>(rc); break;
case Type_t::bf16: cast_vector<Type_t::bf16>(rc); break;
case Type_t::f16: cast_vector<Type_t::f16>(rc); break;
case Type_t::f32: cast_vector<Type_t::f32>(rc); break;
case Type_t::f64: cast_vector<Type_t::f64>(rc); break;
case Type_t::i8: cast_vector<Type_t::i8>(rc); break;
case Type_t::i16: cast_vector<Type_t::i16>(rc); break;
case Type_t::i32: cast_vector<Type_t::i32>(rc); break;
case Type_t::i64: cast_vector<Type_t::i64>(rc); break;
case Type_t::u8: cast_vector<Type_t::u8>(rc); break;
case Type_t::u16: cast_vector<Type_t::u16>(rc); break;
case Type_t::u32: cast_vector<Type_t::u32>(rc); break;
case Type_t::u64: cast_vector<Type_t::u64>(rc); break;
default: throw std::runtime_error("unsupported type");
}
#if defined(_MSC_VER)
@ -471,9 +328,13 @@ namespace ngraph
}
protected:
template <typename IN_T, typename OUT_T>
template <element::Type_t Type, typename OUT_T>
void cast_vector(std::vector<OUT_T>& output_vector) const
{
// this function is workaround for waring during windows building
// build complains for vector creation based on iterators
// which point on different type than destination vector::value_type
using IN_T = fundamental_type_for<Type>;
auto source_vector = get_vector<IN_T>();
output_vector.reserve(source_vector.size());
@ -483,6 +344,15 @@ namespace ngraph
[](IN_T c) { return static_cast<OUT_T>(c); });
}
template <element::Type_t Type,
typename T,
typename DataStorageType = fundamental_type_for<Type>>
void fill_data(const T& value)
{
const auto size = shape_size(m_shape);
std::fill_n(get_data_ptr_nc<Type>(), size, static_cast<DataStorageType>(value));
}
void allocate_buffer();
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
@ -587,12 +457,13 @@ namespace ngraph
#endif
}
bool are_all_data_elements_bitwise_identical() const;
static constexpr size_t host_alignment() { return 64; }
element::Type m_element_type;
Shape m_shape{};
std::shared_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const;
bool m_alloc_buffer_on_visit_attributes = true;
};
}

View File

@ -13,6 +13,9 @@ namespace ngraph
{
};
template <element::Type_t Type>
using fundamental_type_for = typename element_type_traits<Type>::value_type;
template <>
struct element_type_traits<element::Type_t::boolean>
{