diff --git a/ngraph/core/include/ngraph/op/constant.hpp b/ngraph/core/include/ngraph/op/constant.hpp index 243b7bdfada..e07ad6cd7c1 100644 --- a/ngraph/core/include/ngraph/op/constant.hpp +++ b/ngraph/core/include/ngraph/op/constant.hpp @@ -351,80 +351,67 @@ namespace ngraph { case element::Type_t::boolean: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::bf16: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::f16: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::f32: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::f64: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::i8: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::i16: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::i32: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::i64: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::u8: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::u16: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::u32: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } case element::Type_t::u64: { - auto vector = get_vector(); - rc = std::vector(vector.begin(), vector.end()); + cast_vector(rc); break; } default: throw std::runtime_error("unsupported type"); @@ -463,6 +450,18 @@ namespace ngraph std::string convert_value_to_string(size_t index) const; protected: + template + void cast_vector(std::vector& output_vector) const + { + auto source_vector = get_vector(); + output_vector.reserve(source_vector.size()); + + std::transform(source_vector.begin(), + source_vector.end(), + std::back_inserter(output_vector), + [](IN_T c) { return static_cast(c); }); + } + /// \brief Allocate a buffer and return a pointer to it void* allocate_buffer();