Use std::transform instead of vector constructor (#3151)

This commit is contained in:
Tomasz Socha 2020-11-17 04:57:00 +01:00 committed by GitHub
parent e9de4daee7
commit 0f539cc71c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -351,80 +351,67 @@ namespace ngraph
{ {
case element::Type_t::boolean: case element::Type_t::boolean:
{ {
auto vector = get_vector<char>(); cast_vector<char>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::bf16: case element::Type_t::bf16:
{ {
auto vector = get_vector<bfloat16>(); cast_vector<bfloat16>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::f16: case element::Type_t::f16:
{ {
auto vector = get_vector<float16>(); cast_vector<float16>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::f32: case element::Type_t::f32:
{ {
auto vector = get_vector<float>(); cast_vector<float>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::f64: case element::Type_t::f64:
{ {
auto vector = get_vector<double>(); cast_vector<double>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::i8: case element::Type_t::i8:
{ {
auto vector = get_vector<int8_t>(); cast_vector<int8_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::i16: case element::Type_t::i16:
{ {
auto vector = get_vector<int16_t>(); cast_vector<int16_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::i32: case element::Type_t::i32:
{ {
auto vector = get_vector<int32_t>(); cast_vector<int32_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::i64: case element::Type_t::i64:
{ {
auto vector = get_vector<int64_t>(); cast_vector<int64_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::u8: case element::Type_t::u8:
{ {
auto vector = get_vector<uint8_t>(); cast_vector<uint8_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::u16: case element::Type_t::u16:
{ {
auto vector = get_vector<uint16_t>(); cast_vector<uint16_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::u32: case element::Type_t::u32:
{ {
auto vector = get_vector<uint32_t>(); cast_vector<uint32_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
case element::Type_t::u64: case element::Type_t::u64:
{ {
auto vector = get_vector<uint64_t>(); cast_vector<uint64_t>(rc);
rc = std::vector<T>(vector.begin(), vector.end());
break; break;
} }
default: throw std::runtime_error("unsupported type"); default: throw std::runtime_error("unsupported type");
@ -463,6 +450,18 @@ namespace ngraph
std::string convert_value_to_string(size_t index) const; std::string convert_value_to_string(size_t index) const;
protected: protected:
template <typename IN_T, typename OUT_T>
void cast_vector(std::vector<OUT_T>& output_vector) const
{
auto source_vector = get_vector<IN_T>();
output_vector.reserve(source_vector.size());
std::transform(source_vector.begin(),
source_vector.end(),
std::back_inserter(output_vector),
[](IN_T c) { return static_cast<OUT_T>(c); });
}
/// \brief Allocate a buffer and return a pointer to it /// \brief Allocate a buffer and return a pointer to it
void* allocate_buffer(); void* allocate_buffer();