[ONNX] propagate model directory path to Graph, Attribute and Tensor … (#13010)
Required by external data mechanism in ONNX standard, where Tensor object has to be able to find external data file based on the location field and model directory path. Previously, it was done by a transformation, but it handled initializers only, but in order to handle Constant nodes we need one more loop, but over all of the model's nodes. Propagating model directory path to Tensor allows us to reduce that overhead. Ticket: 91271
This commit is contained in:
parent
d2e251a109
commit
a850c8baee
@ -168,46 +168,6 @@ inline std::vector<std::string> get_value(const ONNX_NAMESPACE::AttributeProto&
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Tensor get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
|
||||
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR) {
|
||||
throw error::attribute::InvalidData{attribute.type()};
|
||||
}
|
||||
return Tensor{attribute.t()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<Tensor> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
|
||||
switch (attribute.type()) {
|
||||
case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR:
|
||||
return {Tensor{attribute.t()}};
|
||||
case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS:
|
||||
return {std::begin(attribute.tensors()), std::end(attribute.tensors())};
|
||||
default:
|
||||
throw error::attribute::InvalidData{attribute.type()};
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline SparseTensor get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
|
||||
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR) {
|
||||
throw error::attribute::InvalidData{attribute.type()};
|
||||
}
|
||||
return SparseTensor{attribute.sparse_tensor()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<SparseTensor> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
|
||||
switch (attribute.type()) {
|
||||
case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR:
|
||||
return {SparseTensor{attribute.sparse_tensor()}};
|
||||
case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS:
|
||||
return {std::begin(attribute.sparse_tensors()), std::end(attribute.sparse_tensors())};
|
||||
default:
|
||||
throw error::attribute::InvalidData{attribute.type()};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace attribute
|
||||
|
||||
} // namespace detail
|
||||
@ -231,7 +191,9 @@ public:
|
||||
};
|
||||
|
||||
Attribute() = delete;
|
||||
explicit Attribute(const ONNX_NAMESPACE::AttributeProto& attribute_proto) : m_attribute_proto{&attribute_proto} {}
|
||||
explicit Attribute(const ONNX_NAMESPACE::AttributeProto& attribute_proto, const std::string& model_dir)
|
||||
: m_attribute_proto{&attribute_proto},
|
||||
m_model_dir{model_dir} {}
|
||||
|
||||
Attribute(Attribute&&) noexcept = default;
|
||||
Attribute(const Attribute&) = default;
|
||||
@ -282,10 +244,10 @@ public:
|
||||
return get_type() == Type::graph_array;
|
||||
}
|
||||
Tensor get_tensor() const {
|
||||
return Tensor{m_attribute_proto->t()};
|
||||
return Tensor{m_attribute_proto->t(), m_model_dir};
|
||||
}
|
||||
SparseTensor get_sparse_tensor() const {
|
||||
return SparseTensor{m_attribute_proto->sparse_tensor()};
|
||||
return SparseTensor{m_attribute_proto->sparse_tensor(), m_model_dir};
|
||||
}
|
||||
float get_float() const {
|
||||
return m_attribute_proto->f();
|
||||
@ -299,11 +261,21 @@ public:
|
||||
Subgraph get_subgraph(const Graph* parent_graph) const;
|
||||
|
||||
std::vector<Tensor> get_tensor_array() const {
|
||||
return {std::begin(m_attribute_proto->tensors()), std::end(m_attribute_proto->tensors())};
|
||||
std::vector<Tensor> ret;
|
||||
const auto& tensors = m_attribute_proto->tensors();
|
||||
ret.reserve(tensors.size());
|
||||
for (const auto& tensor : tensors)
|
||||
ret.emplace_back(tensor, m_model_dir);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<SparseTensor> get_sparse_tensor_array() const {
|
||||
return {std::begin(m_attribute_proto->sparse_tensors()), std::end(m_attribute_proto->sparse_tensors())};
|
||||
std::vector<SparseTensor> ret;
|
||||
const auto& sparse_tensors = m_attribute_proto->sparse_tensors();
|
||||
ret.reserve(sparse_tensors.size());
|
||||
for (const auto& tensor : sparse_tensors)
|
||||
ret.emplace_back(tensor, m_model_dir);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<float> get_float_array() const {
|
||||
@ -322,15 +294,56 @@ public:
|
||||
return m_attribute_proto->type();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T,
|
||||
typename std::enable_if<!std::is_same<T, Tensor>::value && !std::is_same<T, std::vector<Tensor>>::value &&
|
||||
!std::is_same<T, SparseTensor>::value &&
|
||||
!std::is_same<T, std::vector<SparseTensor>>::value,
|
||||
bool>::type = true>
|
||||
T get_value() const {
|
||||
return detail::attribute::get_value<T>(*m_attribute_proto);
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_same<T, Tensor>::value, bool>::type = true>
|
||||
T get_value() const {
|
||||
if (is_tensor()) {
|
||||
return Tensor{m_attribute_proto->t(), m_model_dir};
|
||||
}
|
||||
throw error::attribute::InvalidData{m_attribute_proto->type()};
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_same<T, std::vector<Tensor>>::value, bool>::type = true>
|
||||
T get_value() const {
|
||||
if (is_tensor()) {
|
||||
return {Tensor{m_attribute_proto->t(), m_model_dir}};
|
||||
} else if (is_tensor_array()) {
|
||||
return get_tensor_array();
|
||||
}
|
||||
throw error::attribute::InvalidData{m_attribute_proto->type()};
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_same<T, SparseTensor>::value, bool>::type = true>
|
||||
T get_value() const {
|
||||
if (is_sparse_tensor()) {
|
||||
return SparseTensor{m_attribute_proto->sparse_tensor(), m_model_dir};
|
||||
}
|
||||
throw error::attribute::InvalidData{m_attribute_proto->type()};
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_same<T, std::vector<SparseTensor>>::value, bool>::type = true>
|
||||
T get_value() const {
|
||||
if (is_sparse_tensor()) {
|
||||
return {SparseTensor{m_attribute_proto->sparse_tensor(), m_model_dir}};
|
||||
} else if (is_sparse_tensor_array()) {
|
||||
return get_sparse_tensor_array();
|
||||
}
|
||||
throw error::attribute::InvalidData{m_attribute_proto->type()};
|
||||
}
|
||||
|
||||
ov::Any get_any() const;
|
||||
|
||||
private:
|
||||
const ONNX_NAMESPACE::AttributeProto* m_attribute_proto;
|
||||
std::string m_model_dir;
|
||||
};
|
||||
|
||||
} // namespace onnx_import
|
||||
|
@ -128,14 +128,18 @@ ov::frontend::ExtensionHolder subgraph_required_extensions(
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
Graph::Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto, ov::frontend::ExtensionHolder extensions)
|
||||
: Graph(model_proto, common::make_unique<GraphCache>(), std::move(extensions)) {}
|
||||
Graph::Graph(const std::string& model_dir,
|
||||
const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
ov::frontend::ExtensionHolder extensions)
|
||||
: Graph(model_dir, model_proto, common::make_unique<GraphCache>(), std::move(extensions)) {}
|
||||
|
||||
Graph::Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
Graph::Graph(const std::string& model_dir,
|
||||
const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
std::unique_ptr<GraphCache>&& cache,
|
||||
ov::frontend::ExtensionHolder extensions)
|
||||
: m_cache{std::move(cache)},
|
||||
m_extensions{std::move(extensions)} {
|
||||
m_extensions{std::move(extensions)},
|
||||
m_model_dir{model_dir} {
|
||||
const auto ops_bridge = detail::init_ops_bridge(m_extensions.conversions);
|
||||
m_model = common::make_unique<Model>(model_proto, detail::build_model_opset(*model_proto, ops_bridge));
|
||||
|
||||
@ -146,7 +150,7 @@ Graph::Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
// Process all initializers in the graph
|
||||
for (const auto& initializer_tensor : m_model->get_graph().initializer()) {
|
||||
if (initializer_tensor.has_name()) {
|
||||
Tensor tensor = Tensor{initializer_tensor};
|
||||
Tensor tensor = Tensor{initializer_tensor, m_model_dir};
|
||||
std::shared_ptr<default_opset::Constant> ng_constant;
|
||||
// For each initializer create a Constant node and store it in cache
|
||||
try {
|
||||
@ -426,8 +430,9 @@ const OpsetImports& Graph::get_opset_imports() const {
|
||||
return m_model->get_opset_imports();
|
||||
}
|
||||
|
||||
Subgraph::Subgraph(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto, const Graph* parent_graph)
|
||||
: Graph(model_proto,
|
||||
Subgraph::Subgraph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto, const Graph* parent_graph)
|
||||
: Graph(parent_graph->model_dir(),
|
||||
model_proto,
|
||||
common::make_unique<GraphCache>(),
|
||||
detail::subgraph_required_extensions(parent_graph->get_extensions())),
|
||||
m_parent_graph(parent_graph) {}
|
||||
|
@ -21,7 +21,8 @@ namespace ngraph {
|
||||
namespace onnx_import {
|
||||
class Graph : public std::enable_shared_from_this<Graph> {
|
||||
public:
|
||||
Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
Graph(const std::string& model_dir,
|
||||
const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
ov::frontend::ExtensionHolder extensions = {});
|
||||
Graph() = delete;
|
||||
|
||||
@ -36,6 +37,9 @@ public:
|
||||
const std::string& get_name() const {
|
||||
return m_model->get_graph().name();
|
||||
}
|
||||
const std::string& model_dir() const {
|
||||
return m_model_dir;
|
||||
}
|
||||
const ParameterVector& get_ng_parameters() const {
|
||||
return m_parameters;
|
||||
}
|
||||
@ -50,7 +54,8 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model,
|
||||
Graph(const std::string& model_dir,
|
||||
const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model,
|
||||
std::unique_ptr<GraphCache>&& cache,
|
||||
ov::frontend::ExtensionHolder extensions = {});
|
||||
|
||||
@ -70,6 +75,7 @@ protected:
|
||||
|
||||
private:
|
||||
std::vector<Node> m_nodes;
|
||||
std::string m_model_dir;
|
||||
};
|
||||
|
||||
/// \brief Representation of ONNX subgraph. It is used for example by ONNX Loop op.
|
||||
@ -81,7 +87,7 @@ public:
|
||||
///
|
||||
/// \param[in] model The ONNX model object.
|
||||
/// \param[in] parent_graph The reference to the parent graph.
|
||||
Subgraph(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model, const Graph* parent_graph);
|
||||
Subgraph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model, const Graph* parent_graph);
|
||||
|
||||
/// \brief Return nodes which are on the edge the subgraph and the parent graph.
|
||||
/// \return Vector of edge nodes from parent scope.
|
||||
|
@ -22,9 +22,12 @@ public:
|
||||
m_name{node_proto.has_name() ? node_proto.name() : ""},
|
||||
m_domain{get_node_domain(node_proto)},
|
||||
m_graph{&graph},
|
||||
m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())},
|
||||
m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())} {
|
||||
for (const auto& attribute : m_attributes) {
|
||||
const auto& attributes = node_proto.attribute();
|
||||
m_attributes.reserve(attributes.size());
|
||||
for (const auto& attr_proto : attributes) {
|
||||
m_attributes.emplace_back(attr_proto, m_graph->model_dir());
|
||||
const auto& attribute = m_attributes.back();
|
||||
if (attribute.is_graph())
|
||||
m_subgraphs.insert({attribute.get_name(), std::make_shared<Subgraph>(attribute.get_subgraph(m_graph))});
|
||||
}
|
||||
@ -37,9 +40,12 @@ public:
|
||||
m_name{node_proto.has_name() ? node_proto.name() : ""},
|
||||
m_domain{get_node_domain(node_proto)},
|
||||
m_graph{&graph},
|
||||
m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())},
|
||||
m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())},
|
||||
m_subgraphs(subgraphs) {}
|
||||
m_subgraphs(subgraphs) {
|
||||
for (const auto& attr_proto : node_proto.attribute()) {
|
||||
m_attributes.emplace_back(attr_proto, m_graph->model_dir());
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<Attribute>& attributes() const;
|
||||
OutputVector get_ng_inputs() const;
|
||||
|
@ -17,10 +17,10 @@ namespace onnx_import {
|
||||
class SparseTensor {
|
||||
public:
|
||||
SparseTensor() = delete;
|
||||
explicit SparseTensor(const ONNX_NAMESPACE::SparseTensorProto& sparse_tensor)
|
||||
explicit SparseTensor(const ONNX_NAMESPACE::SparseTensorProto& sparse_tensor, const std::string& model_dir)
|
||||
: m_sparse_tensor_proto{&sparse_tensor},
|
||||
m_values{sparse_tensor.values()},
|
||||
m_indices{sparse_tensor.indices()},
|
||||
m_values{sparse_tensor.values(), model_dir},
|
||||
m_indices{sparse_tensor.indices(), model_dir},
|
||||
m_shape{std::begin(sparse_tensor.dims()), std::end(sparse_tensor.dims())} {
|
||||
if (m_shape == Shape{0}) {
|
||||
// It's possible to construct a sparse tensor in ONNX with "dims: 0" property
|
||||
|
205
src/frontends/onnx/frontend/src/core/tensor.cpp
Normal file
205
src/frontends/onnx/frontend/src/core/tensor.cpp
Normal file
@ -0,0 +1,205 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "core/tensor.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
|
||||
template <>
|
||||
std::vector<double> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<double>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<double>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
|
||||
return detail::__get_data<double>(m_tensor_proto->double_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<float> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<float>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<float>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
return detail::__get_data<float>(m_tensor_proto->float_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<ngraph::float16> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<float16>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<ngraph::float16>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
using std::begin;
|
||||
using std::end;
|
||||
|
||||
const auto& int32_data = m_tensor_proto->int32_data();
|
||||
std::vector<ngraph::float16> float16_data;
|
||||
float16_data.reserve(int32_data.size());
|
||||
std::transform(begin(int32_data), end(int32_data), std::back_inserter(float16_data), [](int32_t elem) {
|
||||
return ngraph::float16::from_bits(static_cast<uint16_t>(elem));
|
||||
});
|
||||
|
||||
return detail::__get_data<ngraph::float16>(float16_data);
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<ngraph::bfloat16> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<bfloat16>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<ngraph::bfloat16>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
|
||||
return detail::__get_data<ngraph::bfloat16>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<int8_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<int8_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<int8_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
|
||||
return detail::__get_data<int8_t>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<int16_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<int16_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<int16_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
|
||||
return detail::__get_data<int16_t>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<int32_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<int32_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<int32_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
|
||||
return detail::__get_data<int32_t>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<int64_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<int64_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<int64_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
|
||||
return detail::__get_data<int64_t>(m_tensor_proto->int64_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<uint8_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<uint8_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<uint8_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
return detail::__get_data<uint8_t>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<uint16_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<uint16_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<uint16_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
|
||||
return detail::__get_data<uint16_t>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<uint32_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<uint32_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<uint32_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT32) {
|
||||
return detail::__get_data<uint32_t>(m_tensor_proto->uint64_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<uint64_t> Tensor::get_data() const {
|
||||
if (has_external_data()) {
|
||||
return get_external_data<uint64_t>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<uint64_t>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
|
||||
return detail::__get_data<uint64_t>(m_tensor_proto->uint64_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<char> Tensor::get_data() const {
|
||||
// Boolean values are stored as char because std::vector<bool>
|
||||
// can behave differently from other vector containers.
|
||||
if (has_external_data()) {
|
||||
return get_external_data<char>();
|
||||
}
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return detail::__get_raw_data<char>(m_tensor_proto->raw_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
if (m_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
|
||||
return detail::__get_data<char>(m_tensor_proto->int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
@ -72,263 +72,12 @@ inline std::vector<T> __get_data(const Container& container) {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool has_tensor_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
return tensor.has_data_location() &&
|
||||
tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL;
|
||||
}
|
||||
|
||||
inline std::string load_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
const auto tensor_external_data = TensorExternalData(tensor);
|
||||
return tensor_external_data.load_external_data();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> __get_raw_data(const std::string& raw_data, int onnx_data_type) {
|
||||
auto it = reinterpret_cast<const T*>(raw_data.data());
|
||||
return std::vector<T>(it, it + (raw_data.size() / onnx_common::get_onnx_data_size(onnx_data_type)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> get_external_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
return __get_raw_data<T>(load_external_data(tensor), tensor.data_type());
|
||||
}
|
||||
|
||||
inline const void* get_data_ptr(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (tensor.has_raw_data()) {
|
||||
return tensor.raw_data().data();
|
||||
}
|
||||
switch (tensor.data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return tensor.float_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
return tensor.int32_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
return tensor.int64_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
return tensor.uint64_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
return tensor.double_data().data();
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
inline size_t get_data_size(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (tensor.has_raw_data()) {
|
||||
return tensor.raw_data().size() / onnx_common::get_onnx_data_size(tensor.data_type());
|
||||
}
|
||||
switch (tensor.data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return tensor.float_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
return tensor.int32_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
return tensor.int64_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
return tensor.uint64_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
return tensor.double_data_size();
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
throw ngraph::onnx_import::error::tensor::unsupported_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<double> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<double>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<double>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
|
||||
return __get_data<double>(tensor.double_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<float> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<float>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<float>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
return __get_data<float>(tensor.float_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<ngraph::float16> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<float16>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<ngraph::float16>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
using std::begin;
|
||||
using std::end;
|
||||
|
||||
const auto& int32_data = tensor.int32_data();
|
||||
std::vector<ngraph::float16> float16_data;
|
||||
float16_data.reserve(int32_data.size());
|
||||
std::transform(begin(int32_data), end(int32_data), std::back_inserter(float16_data), [](int32_t elem) {
|
||||
return ngraph::float16::from_bits(static_cast<uint16_t>(elem));
|
||||
});
|
||||
|
||||
return __get_data<ngraph::float16>(float16_data);
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<ngraph::bfloat16> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<bfloat16>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<ngraph::bfloat16>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
|
||||
return __get_data<ngraph::bfloat16>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<int8_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<int8_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<int8_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
|
||||
return __get_data<int8_t>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<int16_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<int16_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<int16_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
|
||||
return __get_data<int16_t>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<int32_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<int32_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<int32_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
|
||||
return __get_data<int32_t>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<int64_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<int64_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<int64_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
|
||||
return __get_data<int64_t>(tensor.int64_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<uint8_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<uint8_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<uint8_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
return __get_data<uint8_t>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<uint16_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<uint16_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<uint16_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
|
||||
return __get_data<uint16_t>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<uint32_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<uint32_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<uint32_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT32) {
|
||||
return __get_data<uint32_t>(tensor.uint64_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<uint64_t> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<uint64_t>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<uint64_t>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
|
||||
return __get_data<uint64_t>(tensor.uint64_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<char> get_data(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
// Boolean values are stored as char because std::vector<bool>
|
||||
// can behave differently from other vector containers.
|
||||
if (has_tensor_external_data(tensor)) {
|
||||
return get_external_data<char>(tensor);
|
||||
}
|
||||
if (tensor.has_raw_data()) {
|
||||
return __get_raw_data<char>(tensor.raw_data(), tensor.data_type());
|
||||
}
|
||||
if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
|
||||
return __get_data<char>(tensor.int32_data());
|
||||
}
|
||||
throw error::tensor::invalid_data_type{tensor.data_type()};
|
||||
}
|
||||
} // namespace
|
||||
} // namespace detail
|
||||
|
||||
@ -355,9 +104,10 @@ public:
|
||||
};
|
||||
|
||||
Tensor() = delete;
|
||||
explicit Tensor(const ONNX_NAMESPACE::TensorProto& tensor)
|
||||
explicit Tensor(const ONNX_NAMESPACE::TensorProto& tensor, const std::string& model_dir)
|
||||
: m_tensor_proto{&tensor},
|
||||
m_shape{std::begin(tensor.dims()), std::end(tensor.dims())} {
|
||||
m_shape{std::begin(tensor.dims()), std::end(tensor.dims())},
|
||||
m_model_dir{model_dir} {
|
||||
if (m_shape == Shape{0}) {
|
||||
// It's possible to construct a tensor in ONNX with "dims: 0" property
|
||||
// Such tensor contains a scalar. This results in a Shape{0} stored in m_shape.
|
||||
@ -380,7 +130,7 @@ public:
|
||||
if (m_tensor_proto->has_segment()) {
|
||||
throw error::tensor::segments_unsupported{};
|
||||
}
|
||||
return detail::get_data<T>(*m_tensor_proto);
|
||||
throw ngraph::onnx_import::error::tensor::unsupported_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
const std::string& get_name() const {
|
||||
@ -483,12 +233,12 @@ private:
|
||||
bool>::type = true>
|
||||
std::shared_ptr<ngraph::op::Constant> make_ng_constant(const element::Type& type) const {
|
||||
std::shared_ptr<default_opset::Constant> constant{nullptr};
|
||||
int data_size = detail::get_data_size(*m_tensor_proto);
|
||||
if (detail::has_tensor_external_data(*m_tensor_proto)) {
|
||||
auto external_data = detail::load_external_data(*m_tensor_proto);
|
||||
int data_size = get_data_size();
|
||||
if (has_external_data()) {
|
||||
auto external_data = load_external_data();
|
||||
constant = std::make_shared<ngraph::op::Constant>(type, m_shape, external_data.data());
|
||||
} else if (data_size == shape_size(m_shape)) {
|
||||
constant = std::make_shared<ngraph::op::Constant>(type, m_shape, detail::get_data_ptr(*m_tensor_proto));
|
||||
constant = std::make_shared<ngraph::op::Constant>(type, m_shape, get_data_ptr());
|
||||
} else if (data_size == 0 && m_shape.size() == 0) {
|
||||
constant = common::make_failsafe_constant(type);
|
||||
} else {
|
||||
@ -523,12 +273,107 @@ private:
|
||||
return constant;
|
||||
}
|
||||
|
||||
bool has_external_data() const {
|
||||
return m_tensor_proto->has_data_location() &&
|
||||
m_tensor_proto->data_location() ==
|
||||
ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL;
|
||||
}
|
||||
|
||||
std::string load_external_data() const {
|
||||
const auto tensor_external_data = detail::TensorExternalData(*m_tensor_proto);
|
||||
return tensor_external_data.load_external_data(m_model_dir);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> get_external_data() const {
|
||||
return detail::__get_raw_data<T>(load_external_data(), m_tensor_proto->data_type());
|
||||
}
|
||||
|
||||
const void* get_data_ptr() const {
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return m_tensor_proto->raw_data().data();
|
||||
}
|
||||
switch (m_tensor_proto->data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return m_tensor_proto->float_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
return m_tensor_proto->int32_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
return m_tensor_proto->int64_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
return m_tensor_proto->uint64_data().data();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
return m_tensor_proto->double_data().data();
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
size_t get_data_size() const {
|
||||
if (m_tensor_proto->has_raw_data()) {
|
||||
return m_tensor_proto->raw_data().size() / onnx_common::get_onnx_data_size(m_tensor_proto->data_type());
|
||||
}
|
||||
switch (m_tensor_proto->data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return m_tensor_proto->float_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
return m_tensor_proto->int32_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
return m_tensor_proto->int64_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
return m_tensor_proto->uint64_data_size();
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
return m_tensor_proto->double_data_size();
|
||||
}
|
||||
throw error::tensor::invalid_data_type{m_tensor_proto->data_type()};
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* m_tensor_proto;
|
||||
Shape m_shape;
|
||||
std::string m_model_dir;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& outs, const Tensor& tensor) {
|
||||
return (outs << "<Tensor: " << tensor.get_name() << ">");
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<double> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<float> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<ngraph::float16> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<ngraph::bfloat16> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<int8_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<int16_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<int32_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<int64_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<uint8_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<uint16_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<uint32_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<uint64_t> Tensor::get_data() const;
|
||||
|
||||
template <>
|
||||
std::vector<char> Tensor::get_data() const;
|
||||
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -11,7 +11,6 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "core/model.hpp"
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ops_bridge.hpp"
|
||||
|
||||
@ -106,34 +105,6 @@ void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::Model
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::onnx_import::transform::update_external_data_paths(ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::string& model_path) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
if (model_path.empty()) {
|
||||
return;
|
||||
}
|
||||
const auto model_dir_path = file_util::get_directory(model_path);
|
||||
auto graph_proto = model_proto.mutable_graph();
|
||||
for (auto& initializer_tensor : *graph_proto->mutable_initializer()) {
|
||||
const auto location_key_value_index = 0;
|
||||
if (initializer_tensor.has_data_location() &&
|
||||
initializer_tensor.data_location() ==
|
||||
ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) {
|
||||
const auto external_data_relative_path = initializer_tensor.external_data(location_key_value_index).value();
|
||||
const auto santized_external_data_relative_path = file_util::sanitize_path(external_data_relative_path);
|
||||
auto external_data_full_path = file_util::path_join(model_dir_path, santized_external_data_relative_path);
|
||||
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
file_util::convert_path_win_style(external_data_full_path);
|
||||
#endif
|
||||
|
||||
// Set full paths to the external file
|
||||
initializer_tensor.mutable_external_data(location_key_value_index)->set_value(external_data_full_path);
|
||||
}
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void ngraph::onnx_import::transform::fixup_legacy_operators(ONNX_NAMESPACE::ModelProto& model_proto) {
|
||||
auto graph_proto = model_proto.mutable_graph();
|
||||
for (auto& node : *graph_proto->mutable_node()) {
|
||||
|
@ -10,16 +10,6 @@ namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace transform {
|
||||
|
||||
/// \brief Replace external_data path in tensors with full path to data file.
|
||||
///
|
||||
/// Paths to external data files are stored as relative to model path.
|
||||
/// This transformation replaces them with a full filesystem path.
|
||||
/// As a result in further processing data from external files can be read directly.
|
||||
///
|
||||
/// \param model_proto Protobuf message with ONNX model to transform.
|
||||
/// \param model_path Filesystem path to the ONNX model file.
|
||||
void update_external_data_paths(ONNX_NAMESPACE::ModelProto& model_proto, const std::string& model_path);
|
||||
|
||||
static const std::vector<std::string> onnx_functions_to_expand =
|
||||
{"Bernoulli", "Celu", "GreaterOrEqual", "LessOrEqual", "NegativeLogLikelihoodLoss", "SoftmaxCrossEntropyLoss"};
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "core/graph.hpp"
|
||||
#include "core/model.hpp"
|
||||
#include "core/transform.hpp"
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "onnx_framework_node.hpp"
|
||||
#include "onnx_import/core/null_node.hpp"
|
||||
|
||||
@ -51,9 +52,8 @@ void remove_dangling_results(std::shared_ptr<Function>& function) {
|
||||
}
|
||||
}
|
||||
|
||||
void apply_transformations(ONNX_NAMESPACE::ModelProto& model_proto, const std::string& model_path) {
|
||||
void apply_transformations(ONNX_NAMESPACE::ModelProto& model_proto) {
|
||||
transform::fixup_legacy_operators(model_proto);
|
||||
transform::update_external_data_paths(model_proto, model_path);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -89,16 +89,20 @@ void convert_decoded_function(std::shared_ptr<Function> function) {
|
||||
std::shared_ptr<Function> import_onnx_model(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto,
|
||||
const std::string& model_path,
|
||||
ov::frontend::ExtensionHolder extensions) {
|
||||
apply_transformations(*model_proto, model_path);
|
||||
Graph graph{model_proto, std::move(extensions)};
|
||||
apply_transformations(*model_proto);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
Graph graph{file_util::get_directory(model_path), model_proto, std::move(extensions)};
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
return graph.convert();
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> decode_to_framework_nodes(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto,
|
||||
const std::string& model_path,
|
||||
ov::frontend::ExtensionHolder extensions) {
|
||||
apply_transformations(*model_proto, model_path);
|
||||
auto graph = std::make_shared<Graph>(model_proto, extensions);
|
||||
apply_transformations(*model_proto);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
auto graph = std::make_shared<Graph>(file_util::get_directory(model_path), model_proto, extensions);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
return graph->decode();
|
||||
}
|
||||
} // namespace detail
|
||||
|
@ -17,26 +17,33 @@ namespace onnx_import {
|
||||
namespace detail {
|
||||
TensorExternalData::TensorExternalData(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
for (const auto& entry : tensor.external_data()) {
|
||||
if (entry.key() == "location")
|
||||
m_data_location = entry.value();
|
||||
if (entry.key() == "offset")
|
||||
if (entry.key() == "location") {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
m_data_location = file_util::sanitize_path(entry.value());
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
} else if (entry.key() == "offset") {
|
||||
m_offset = std::stoi(entry.value());
|
||||
if (entry.key() == "length")
|
||||
} else if (entry.key() == "length") {
|
||||
m_data_length = std::stoi(entry.value());
|
||||
if (entry.key() == "checksum")
|
||||
} else if (entry.key() == "checksum") {
|
||||
m_sha1_digest = std::stoi(entry.value());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string TensorExternalData::load_external_data() const {
|
||||
std::string TensorExternalData::load_external_data(const std::string& model_dir) const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
auto full_path = file_util::path_join(model_dir, m_data_location);
|
||||
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
std::wstring path = ov::util::string_to_wstring(m_data_location);
|
||||
file_util::convert_path_win_style(full_path);
|
||||
std::ifstream external_data_stream(ov::util::string_to_wstring(full_path),
|
||||
std::ios::binary | std::ios::in | std::ios::ate);
|
||||
#else
|
||||
std::string path = m_data_location;
|
||||
std::ifstream external_data_stream(full_path, std::ios::binary | std::ios::in | std::ios::ate);
|
||||
#endif
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
std::ifstream external_data_stream(path, std::ios::binary | std::ios::in | std::ios::ate);
|
||||
|
||||
if (external_data_stream.fail())
|
||||
throw error::invalid_external_data{*this};
|
||||
|
||||
|
@ -21,7 +21,7 @@ public:
|
||||
/// the invalid_external_data exception is thrown.
|
||||
///
|
||||
/// \return External binary data loaded into a std::string
|
||||
std::string load_external_data() const;
|
||||
std::string load_external_data(const std::string& model_dir) const;
|
||||
|
||||
/// \brief Represets parameter of external data as string
|
||||
///
|
||||
|
@ -0,0 +1,66 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
output: "B"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 2
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "const_tensor"
|
||||
external_data {
|
||||
key: "location",
|
||||
value: "tensors_data/tensor.data"
|
||||
}
|
||||
data_location: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "A"
|
||||
input: "B"
|
||||
output: "X"
|
||||
name: "add_node1"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "test_graph"
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 4
|
||||
}
|
@ -138,3 +138,16 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_external_data_sanitize_path) {
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_external_data_in_constant_node) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/external_data/external_data_in_constant_node.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<float>({3.f, 5.f, 8.f, 13.f});
|
||||
test_case.add_expected_output<float>(Shape{2, 2}, {4.f, 7.f, 11.f, 17.f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user