From 75f9242cb4c5befd0469fb02890362721e7b4473 Mon Sep 17 00:00:00 2001 From: Patryk Elszkowski Date: Tue, 13 Apr 2021 14:32:29 +0200 Subject: [PATCH] cleanup constant op class (#5186) Co-authored-by: Patryk Elszkowski --- ngraph/core/include/ngraph/op/constant.hpp | 229 ++++-------------- .../ngraph/type/element_type_traits.hpp | 3 + 2 files changed, 53 insertions(+), 179 deletions(-) diff --git a/ngraph/core/include/ngraph/op/constant.hpp b/ngraph/core/include/ngraph/op/constant.hpp index 3cd0d3ac09f..860a6979965 100644 --- a/ngraph/core/include/ngraph/op/constant.hpp +++ b/ngraph/core/include/ngraph/op/constant.hpp @@ -71,7 +71,7 @@ namespace ngraph m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical(); } - /// \brief Create unitialized constant + /// \brief Create uninitialized constant Constant(const element::Type& type, const Shape& shape); /// \brief Constructs a uniform tensor constant. /// @@ -84,7 +84,7 @@ namespace ngraph Constant(const element::Type& type, const Shape& shape, T value) : Constant(type, shape) { - auto size = shape_size(m_shape); + using Type_t = element::Type_t; #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) #pragma GCC diagnostic push #pragma GCC diagnostic error "-Wswitch" @@ -92,115 +92,24 @@ namespace ngraph #endif switch (type) { - case element::Type_t::boolean: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::bf16: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::f16: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::f32: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::f64: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::i8: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::i16: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::i32: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::i64: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::u8: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::u16: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::u32: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::u64: - std::fill_n( - get_data_ptr_nc(), - size, - static_cast< - typename element_type_traits::value_type>( - value)); - break; - case element::Type_t::i4: - case element::Type_t::u1: - case element::Type_t::u4: - case element::Type_t::undefined: - case element::Type_t::dynamic: throw std::runtime_error("unsupported type"); + case Type_t::boolean: fill_data(value); break; + case Type_t::bf16: fill_data(value); break; + case Type_t::f16: fill_data(value); break; + case Type_t::f32: fill_data(value); break; + case Type_t::f64: fill_data(value); break; + case Type_t::i8: fill_data(value); break; + case Type_t::i16: fill_data(value); break; + case Type_t::i32: fill_data(value); break; + case Type_t::i64: fill_data(value); break; + case Type_t::u8: fill_data(value); break; + case Type_t::u16: fill_data(value); break; + case Type_t::u32: fill_data(value); break; + case Type_t::u64: fill_data(value); break; + case Type_t::i4: + case Type_t::u1: + case Type_t::u4: + case Type_t::undefined: + case Type_t::dynamic: throw std::runtime_error("unsupported type"); } #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) #pragma GCC diagnostic pop @@ -355,78 +264,26 @@ namespace ngraph { auto source_type = get_element_type(); std::vector rc; - + using Type_t = element::Type_t; #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4244) #endif switch (source_type) { - case element::Type_t::boolean: - { - cast_vector(rc); - break; - } - case element::Type_t::bf16: - { - cast_vector(rc); - break; - } - case element::Type_t::f16: - { - cast_vector(rc); - break; - } - case element::Type_t::f32: - { - cast_vector(rc); - break; - } - case element::Type_t::f64: - { - cast_vector(rc); - break; - } - case element::Type_t::i8: - { - cast_vector(rc); - break; - } - case element::Type_t::i16: - { - cast_vector(rc); - break; - } - case element::Type_t::i32: - { - cast_vector(rc); - break; - } - case element::Type_t::i64: - { - cast_vector(rc); - break; - } - case element::Type_t::u8: - { - cast_vector(rc); - break; - } - case element::Type_t::u16: - { - cast_vector(rc); - break; - } - case element::Type_t::u32: - { - cast_vector(rc); - break; - } - case element::Type_t::u64: - { - cast_vector(rc); - break; - } + case Type_t::boolean: cast_vector(rc); break; + case Type_t::bf16: cast_vector(rc); break; + case Type_t::f16: cast_vector(rc); break; + case Type_t::f32: cast_vector(rc); break; + case Type_t::f64: cast_vector(rc); break; + case Type_t::i8: cast_vector(rc); break; + case Type_t::i16: cast_vector(rc); break; + case Type_t::i32: cast_vector(rc); break; + case Type_t::i64: cast_vector(rc); break; + case Type_t::u8: cast_vector(rc); break; + case Type_t::u16: cast_vector(rc); break; + case Type_t::u32: cast_vector(rc); break; + case Type_t::u64: cast_vector(rc); break; default: throw std::runtime_error("unsupported type"); } #if defined(_MSC_VER) @@ -471,9 +328,13 @@ namespace ngraph } protected: - template + template void cast_vector(std::vector& output_vector) const { + // this function is workaround for waring during windows building + // build complains for vector creation based on iterators + // which point on different type than destination vector::value_type + using IN_T = fundamental_type_for; auto source_vector = get_vector(); output_vector.reserve(source_vector.size()); @@ -483,6 +344,15 @@ namespace ngraph [](IN_T c) { return static_cast(c); }); } + template > + void fill_data(const T& value) + { + const auto size = shape_size(m_shape); + std::fill_n(get_data_ptr_nc(), size, static_cast(value)); + } + void allocate_buffer(); void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); } @@ -587,12 +457,13 @@ namespace ngraph #endif } + bool are_all_data_elements_bitwise_identical() const; static constexpr size_t host_alignment() { return 64; } + element::Type m_element_type; Shape m_shape{}; std::shared_ptr m_data; bool m_all_elements_bitwise_identical; - bool are_all_data_elements_bitwise_identical() const; bool m_alloc_buffer_on_visit_attributes = true; }; } diff --git a/ngraph/core/include/ngraph/type/element_type_traits.hpp b/ngraph/core/include/ngraph/type/element_type_traits.hpp index 8954f97172b..d24b3c7df41 100644 --- a/ngraph/core/include/ngraph/type/element_type_traits.hpp +++ b/ngraph/core/include/ngraph/type/element_type_traits.hpp @@ -13,6 +13,9 @@ namespace ngraph { }; + template + using fundamental_type_for = typename element_type_traits::value_type; + template <> struct element_type_traits {