Use std::transform instead of vector constructor (#3151)

This commit is contained in:
Tomasz Socha 2020-11-17 04:57:00 +01:00 committed by GitHub
parent e9de4daee7
commit 0f539cc71c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -351,80 +351,67 @@ namespace ngraph
{
case element::Type_t::boolean:
{
auto vector = get_vector<char>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<char>(rc);
break;
}
case element::Type_t::bf16:
{
auto vector = get_vector<bfloat16>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<bfloat16>(rc);
break;
}
case element::Type_t::f16:
{
auto vector = get_vector<float16>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<float16>(rc);
break;
}
case element::Type_t::f32:
{
auto vector = get_vector<float>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<float>(rc);
break;
}
case element::Type_t::f64:
{
auto vector = get_vector<double>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<double>(rc);
break;
}
case element::Type_t::i8:
{
auto vector = get_vector<int8_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<int8_t>(rc);
break;
}
case element::Type_t::i16:
{
auto vector = get_vector<int16_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<int16_t>(rc);
break;
}
case element::Type_t::i32:
{
auto vector = get_vector<int32_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<int32_t>(rc);
break;
}
case element::Type_t::i64:
{
auto vector = get_vector<int64_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<int64_t>(rc);
break;
}
case element::Type_t::u8:
{
auto vector = get_vector<uint8_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<uint8_t>(rc);
break;
}
case element::Type_t::u16:
{
auto vector = get_vector<uint16_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<uint16_t>(rc);
break;
}
case element::Type_t::u32:
{
auto vector = get_vector<uint32_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<uint32_t>(rc);
break;
}
case element::Type_t::u64:
{
auto vector = get_vector<uint64_t>();
rc = std::vector<T>(vector.begin(), vector.end());
cast_vector<uint64_t>(rc);
break;
}
default: throw std::runtime_error("unsupported type");
@ -463,6 +450,18 @@ namespace ngraph
std::string convert_value_to_string(size_t index) const;
protected:
template <typename IN_T, typename OUT_T>
void cast_vector(std::vector<OUT_T>& output_vector) const
{
auto source_vector = get_vector<IN_T>();
output_vector.reserve(source_vector.size());
std::transform(source_vector.begin(),
source_vector.end(),
std::back_inserter(output_vector),
[](IN_T c) { return static_cast<OUT_T>(c); });
}
/// \brief Allocate a buffer and return a pointer to it
void* allocate_buffer();