[ONNX] Avoid allocating vector for constants if possible (#10860)

For FLOAT, DOUBLE, INT32, INT64, UINT64 we can get a pointer
to data from TensorProto and pass it to Constant constructor.
This commit is contained in:
Mateusz Tabaka
2022-03-20 14:07:50 +01:00
committed by GitHub
parent 5390aa7ebc
commit c18030207c
3 changed files with 210 additions and 208 deletions

View File

@@ -131,6 +131,29 @@ graph {
dims: 3
data_type: 1
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
float_data: 4.0
}
type: TENSOR
}
@@ -146,6 +169,21 @@ graph {
dims: 2
data_type: 1
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
float_data: 2.0
}
type: TENSOR
}
@@ -160,6 +198,21 @@ graph {
dims: 16
data_type: 1
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
float_data: 3.0
}
type: TENSOR
}

View File

@@ -51,13 +51,15 @@ struct data_type_undefined : ngraph_error {
struct segments_unsupported : ngraph_error {
segments_unsupported() : ngraph_error{"loading segments not supported"} {}
};
struct shape_doesnt_match_data_size : ngraph_error {
shape_doesnt_match_data_size() : ngraph_error{"tensor shape doesn't match data size"} {}
};
} // namespace tensor
} // namespace error
namespace detail {
namespace tensor {
namespace {
namespace detail {
template <typename T, typename Container>
inline std::vector<T> __get_data(const Container& container) {
#if defined(_MSC_VER)
@@ -70,6 +72,16 @@ inline std::vector<T> __get_data(const Container& container) {
#endif
}
bool has_tensor_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
return tensor.has_data_location() &&
tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL;
}
inline std::string load_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
const auto tensor_external_data = TensorExternalData(tensor);
return tensor_external_data.load_external_data();
}
template <typename T>
inline std::vector<T> __get_raw_data(const std::string& raw_data, int onnx_data_type) {
auto it = reinterpret_cast<const T*>(raw_data.data());
@@ -78,22 +90,46 @@ inline std::vector<T> __get_raw_data(const std::string& raw_data, int onnx_data_
template <typename T>
inline std::vector<T> get_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
const auto tensor_external_data = TensorExternalData(tensor);
const auto raw_data = tensor_external_data.load_external_data();
return detail::__get_raw_data<T>(raw_data, tensor.data_type());
return __get_raw_data<T>(load_external_data(tensor), tensor.data_type());
}
bool has_tensor_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (tensor.has_data_location() &&
tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) {
return true;
} else {
return false;
inline const void* get_data_ptr(const ONNX_NAMESPACE::TensorProto& tensor) {
if (tensor.has_raw_data()) {
return tensor.raw_data().data();
}
switch (tensor.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
return tensor.float_data().data();
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
return tensor.int32_data().data();
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
return tensor.int64_data().data();
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
return tensor.uint64_data().data();
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
return tensor.double_data().data();
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
inline size_t get_data_size(const ONNX_NAMESPACE::TensorProto& tensor) {
if (tensor.has_raw_data()) {
return tensor.raw_data().size() / onnx_common::get_onnx_data_size(tensor.data_type());
}
switch (tensor.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
return tensor.float_data_size();
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
return tensor.int32_data_size();
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
return tensor.int64_data_size();
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
return tensor.uint64_data_size();
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
return tensor.double_data_size();
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
} // namespace detail
} // namespace
template <typename T>
inline std::vector<T> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
@@ -102,60 +138,39 @@ inline std::vector<T> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
template <>
inline std::vector<double> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<double>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<double>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<double>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<double>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
return detail::__get_data<double>(tensor.double_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return detail::__get_data<double>(tensor.float_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
return detail::__get_data<double>(tensor.int32_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
return detail::__get_data<double>(tensor.int64_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
return detail::__get_data<double>(tensor.uint64_data());
return __get_data<double>(tensor.double_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<float> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<float>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<float>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<float>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<float>(tensor.raw_data(), tensor.data_type());
}
if ((tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
return detail::__get_data<float>(tensor.float_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
return detail::__get_data<float>(tensor.int32_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
return detail::__get_data<float>(tensor.int64_data());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
return detail::__get_data<float>(tensor.uint64_data());
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return __get_data<float>(tensor.float_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<ngraph::float16> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<float16>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<float16>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<ngraph::float16>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<ngraph::float16>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
using std::begin;
@@ -168,153 +183,153 @@ inline std::vector<ngraph::float16> get_data(const ONNX_NAMESPACE::TensorProto&
return ngraph::float16::from_bits(static_cast<uint16_t>(elem));
});
return detail::__get_data<ngraph::float16>(float16_data);
return __get_data<ngraph::float16>(float16_data);
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<ngraph::bfloat16> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<bfloat16>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<bfloat16>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<ngraph::bfloat16>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<ngraph::bfloat16>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
return detail::__get_data<ngraph::bfloat16>(tensor.int32_data());
return __get_data<ngraph::bfloat16>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int8_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<int8_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<int8_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<int8_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<int8_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
return detail::__get_data<int8_t>(tensor.int32_data());
return __get_data<int8_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int16_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<int16_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<int16_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<int16_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<int16_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
return detail::__get_data<int16_t>(tensor.int32_data());
return __get_data<int16_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int32_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<int32_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<int32_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<int32_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<int32_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
return detail::__get_data<int32_t>(tensor.int32_data());
return __get_data<int32_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int64_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<int64_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<int64_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<int64_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<int64_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
throw error::tensor::invalid_data_type{tensor.data_type()};
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
return __get_data<int64_t>(tensor.int64_data());
}
return detail::__get_data<int64_t>(tensor.int64_data());
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint8_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<uint8_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<uint8_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<uint8_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<uint8_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
return detail::__get_data<uint8_t>(tensor.int32_data());
return __get_data<uint8_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint16_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<uint16_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<uint16_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<uint16_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<uint16_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
return detail::__get_data<uint16_t>(tensor.int32_data());
return __get_data<uint16_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint32_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<uint32_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<uint32_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<uint32_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<uint32_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT32) {
return detail::__get_data<uint32_t>(tensor.uint64_data());
return __get_data<uint32_t>(tensor.uint64_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint64_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<uint64_t>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<uint64_t>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<uint64_t>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<uint64_t>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
throw error::tensor::invalid_data_type{tensor.data_type()};
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
return __get_data<uint64_t>(tensor.uint64_data());
}
return detail::__get_data<uint64_t>(tensor.uint64_data());
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<char> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
// Boolean values are stored as char because std::vector<bool>
// can behave differently from other vector containers.
if (detail::has_tensor_external_data(tensor)) {
return detail::get_external_data<char>(tensor);
if (has_tensor_external_data(tensor)) {
return get_external_data<char>(tensor);
}
if (tensor.has_raw_data()) {
return detail::__get_raw_data<char>(tensor.raw_data(), tensor.data_type());
return __get_raw_data<char>(tensor.raw_data(), tensor.data_type());
}
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
return detail::__get_data<char>(tensor.int32_data());
return __get_data<char>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
} // namespace tensor
} // namespace
} // namespace detail
class Tensor {
@@ -365,7 +380,7 @@ public:
if (m_tensor_proto->has_segment()) {
throw error::tensor::segments_unsupported{};
}
return detail::tensor::get_data<T>(*m_tensor_proto);
return detail::get_data<T>(*m_tensor_proto);
}
const std::string& get_name() const {
@@ -423,7 +438,11 @@ public:
operator TensorProto_DataType() const {
return m_tensor_proto->data_type();
}
std::shared_ptr<ngraph::op::Constant> get_ng_constant() const {
if (m_tensor_proto->has_segment()) {
throw error::tensor::segments_unsupported{};
}
switch (m_tensor_proto->data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL:
return make_ng_constant<char>(element::boolean);
@@ -457,9 +476,47 @@ public:
}
private:
template <typename T>
template <typename T,
typename std::enable_if<std::is_same<T, float>::value || std::is_same<T, double>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
std::is_same<T, uint64_t>::value,
bool>::type = true>
std::shared_ptr<ngraph::op::Constant> make_ng_constant(const element::Type& type) const {
auto constant = std::make_shared<ngraph::op::Constant>(type, m_shape, get_data<T>());
std::shared_ptr<default_opset::Constant> constant{nullptr};
int data_size = detail::get_data_size(*m_tensor_proto);
if (detail::has_tensor_external_data(*m_tensor_proto)) {
auto external_data = detail::load_external_data(*m_tensor_proto);
constant = std::make_shared<ngraph::op::Constant>(type, m_shape, external_data.data());
} else if (data_size == shape_size(m_shape)) {
constant = std::make_shared<ngraph::op::Constant>(type, m_shape, detail::get_data_ptr(*m_tensor_proto));
} else if (data_size == 0 && m_shape.size() == 0) {
constant = common::make_failsafe_constant(type);
} else {
throw error::tensor::shape_doesnt_match_data_size{};
}
if (m_tensor_proto->has_name()) {
constant->set_friendly_name(get_name());
}
return constant;
}
template <typename T,
typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, double>::value &&
!std::is_same<T, int32_t>::value && !std::is_same<T, int64_t>::value &&
!std::is_same<T, uint64_t>::value,
bool>::type = true>
std::shared_ptr<ngraph::op::Constant> make_ng_constant(const element::Type& type) const {
std::shared_ptr<default_opset::Constant> constant{nullptr};
auto data = get_data<T>();
auto data_size = data.size();
if (data_size == shape_size(m_shape)) {
constant = std::make_shared<ngraph::op::Constant>(type, m_shape, data);
} else if (data_size == 0 && m_shape.size() == 0) {
constant = common::make_failsafe_constant(type);
} else {
throw error::tensor::shape_doesnt_match_data_size{};
}
if (m_tensor_proto->has_name()) {
constant->set_friendly_name(get_name());
}

View File

@@ -18,112 +18,6 @@ namespace ngraph {
namespace onnx_import {
namespace op {
namespace {
template <typename T>
inline std::shared_ptr<default_opset::Constant> make_ng_constant_impl(const element::Type& type, const Tensor& tensor) {
std::shared_ptr<default_opset::Constant> constant{nullptr};
try {
constant = std::make_shared<default_opset::Constant>(type, tensor.get_shape(), tensor.get_data<T>());
} catch (const ngraph::ngraph_error&) {
constant = common::make_failsafe_constant(type);
}
return constant;
}
template <Tensor::Type>
inline std::shared_ptr<default_opset::Constant> make_ng_constant(const Tensor& tensor) {
throw error::tensor::unsupported_data_type{tensor};
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::float16>(const Tensor& tensor) {
return make_ng_constant_impl<ngraph::float16>(element::f16, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::float32>(const Tensor& tensor) {
return make_ng_constant_impl<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::float64>(const Tensor& tensor) {
return make_ng_constant_impl<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::int8>(const Tensor& tensor) {
return make_ng_constant_impl<int8_t>(element::i8, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::int16>(const Tensor& tensor) {
return make_ng_constant_impl<int16_t>(element::i16, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::int32>(const Tensor& tensor) {
return make_ng_constant_impl<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::int64>(const Tensor& tensor) {
return make_ng_constant_impl<int64_t>(element::i64, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::uint8>(const Tensor& tensor) {
return make_ng_constant_impl<uint8_t>(element::u8, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::uint16>(const Tensor& tensor) {
return make_ng_constant_impl<uint16_t>(element::u16, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor) {
return make_ng_constant_impl<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor) {
return make_ng_constant_impl<uint64_t>(element::u64, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::boolean>(const Tensor& tensor) {
return make_ng_constant_impl<char>(element::boolean, tensor);
}
template <>
inline std::shared_ptr<default_opset::Constant> make_ng_constant<Tensor::Type::bfloat16>(const Tensor& tensor) {
return make_ng_constant_impl<ngraph::bfloat16>(element::bf16, tensor);
}
inline std::shared_ptr<default_opset::Constant> make_constant(const Tensor& tensor) {
#define MAKE_NG_CONSTANT(data_type_) \
case data_type_: \
return make_ng_constant<data_type_>(tensor)
switch (tensor.get_type()) {
MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int8);
MAKE_NG_CONSTANT(Tensor::Type::int16);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint8);
MAKE_NG_CONSTANT(Tensor::Type::uint16);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
MAKE_NG_CONSTANT(Tensor::Type::boolean);
MAKE_NG_CONSTANT(Tensor::Type::bfloat16);
default:
throw error::tensor::invalid_data_type{tensor};
}
}
template <typename T>
std::vector<T> get_dense_vector(const std::vector<T>& values, const std::vector<int64_t>& indices, const size_t size) {
NGRAPH_CHECK(values.size() == indices.size(),
@@ -212,7 +106,8 @@ std::vector<int64_t> get_absolute_indices(const Tensor& indices_tensor, const Sh
namespace set_1 {
OutputVector constant(const onnx_import::Node& node) {
return {make_constant(node.get_attribute_value<Tensor>("value"))};
auto tensor = node.get_attribute_value<Tensor>("value");
return {tensor.get_ng_constant()};
}
} // namespace set_1
@@ -283,13 +178,10 @@ OutputVector constant(const onnx_import::Node& node) {
}
return {get_dense_tensor_as_constant(absolute_indices, values_tensor, shape)};
}
return {make_constant(node.get_attribute_value<Tensor>(attributes_names[0]))};
auto tensor = node.get_attribute_value<Tensor>(attributes_names[0]);
return {tensor.get_ng_constant()};
}
} // namespace set_13
} // namespace op
} // namespace onnx_import
} // namespace ngraph