[ONNX Editor] Add initializers for model inputs (#4025)

This commit is contained in:
Tomasz Jankowski
2021-02-16 14:12:27 +01:00
committed by GitHub
parent 0a36e4e810
commit 5e17926604
10 changed files with 442 additions and 30 deletions

View File

@@ -20,6 +20,7 @@
#include <map>
#include <memory>
#include "ngraph/op/constant.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx_import/utils/onnx_importer_visibility.hpp"
@@ -70,6 +71,20 @@ namespace ngraph
/// the inputs specified in its parameter.
void set_input_shapes(const std::map<std::string, ngraph::PartialShape>& input_shapes);
/// \brief Modifies the in-memory representation of the model by setting custom input
/// values for inputs specified in the provided map.
///
/// \note This method modifies existing initializer tensor if its name matches one of
/// input_name. Otherwise it adds initializer tensor into the model.
/// If input tensor of matching name is present in the model, its type and shape
/// are modified accordingly.
///
/// \param input_values A collection of pairs {input_name: new_input_values} used to
/// update the ONNX model. Initializers already existing are
/// overwritten.
void set_input_values(
const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values);
/// \brief Returns a non-const reference to the underlying ModelProto object, possibly
/// modified by the editor's API calls
///

View File

@@ -23,6 +23,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
#include "utils/tensor_external_data.hpp"
namespace ngraph
@@ -129,42 +130,15 @@ namespace ngraph
#endif
}
/// Returns the size if bytes of an ONNX data type.
inline size_t __get_onnx_data_size(int data_type)
{
switch (data_type)
{
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: return sizeof(float);
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: return sizeof(uint8_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT8: return sizeof(int8_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
return sizeof(uint16_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT16: return sizeof(int16_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT32: return sizeof(int32_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT64: return sizeof(int64_t);
case ONNX_NAMESPACE::TensorProto_DataType_BOOL: return sizeof(char);
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: return 2;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: return sizeof(double);
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
return sizeof(uint32_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
return sizeof(uint64_t);
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64:
return 2 * sizeof(float);
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128:
return 2 * sizeof(double);
default: NGRAPH_UNREACHABLE("Unsupported data type");
}
}
template <typename T>
inline std::vector<T> __get_raw_data(const std::string& raw_data,
int onnx_data_type)
{
auto it = reinterpret_cast<const T*>(raw_data.data());
return std::vector<T>(
it, it + (raw_data.size() / __get_onnx_data_size(onnx_data_type)));
it,
it +
(raw_data.size() / common::get_onnx_data_size(onnx_data_type)));
}
template <typename T>

View File

@@ -17,7 +17,9 @@
#include <fstream>
#include <onnx/onnx_pb.h>
#include "ngraph/log.hpp"
#include "onnx_import/editor/editor.hpp"
#include "utils/common.hpp"
#include "utils/parser.hpp"
using namespace ngraph;
@@ -55,6 +57,18 @@ namespace
return nullptr;
}
TensorProto* find_graph_initializer(GraphProto& graph, const std::string& name)
{
for (int i = 0; i < graph.initializer_size(); ++i)
{
auto* initializer_desc = graph.mutable_initializer(i);
if (initializer_desc->has_name() && initializer_desc->name() == name)
return initializer_desc;
}
return nullptr;
}
void modify_input_type(ValueInfoProto& onnx_input, const element::Type_t elem_type)
{
if (!onnx_input.has_type())
@@ -142,6 +156,48 @@ namespace
*(tensor_type->mutable_shape()) = std::move(new_onnx_shape);
}
}
void modify_initializer(TensorProto& initializer,
const std::string& name,
const std::shared_ptr<ngraph::op::Constant> values,
ValueInfoProto* input)
{
const auto elem_type = values->get_element_type();
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
{
throw ngraph_error("Initializer '" + name + "' type cannot be set to: " +
element::Type(elem_type).get_type_name() +
". This type is not allowed in ONNX.");
}
initializer.Clear();
initializer.set_name(name);
initializer.set_data_type(NG_2_ONNX_TYPES.at(values->get_element_type()));
for (const auto& dim : values->get_shape())
{
initializer.add_dims(dim);
}
const auto data_size_in_bytes =
shape_size(values->get_shape()) *
onnx_import::common::get_onnx_data_size(initializer.data_type());
initializer.set_raw_data(values->get_data_ptr(), data_size_in_bytes);
// update input with type and shape of initializer
if (input)
{
auto tensor_type = input->mutable_type()->mutable_tensor_type();
TensorShapeProto shape;
for (size_t i = 0; i < initializer.dims_size(); ++i)
{
shape.add_dim()->set_dim_value(initializer.dims(i));
}
*tensor_type->mutable_shape() = std::move(shape);
tensor_type->set_elem_type(initializer.data_type());
}
}
} // namespace
/// \brief A helper class used to hold the ModelProto object as its field
@@ -232,3 +288,31 @@ void onnx_import::ONNXModelEditor::set_input_shapes(
}
}
}
void onnx_import::ONNXModelEditor::set_input_values(
const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values)
{
auto onnx_graph = m_pimpl->m_model_proto.mutable_graph();
for (const auto& input : input_values)
{
auto& name = input.first;
auto& values = input.second;
auto onnx_input = find_graph_input(*onnx_graph, name);
auto onnx_initializer = find_graph_initializer(*onnx_graph, name);
if (!onnx_initializer && !onnx_input)
{
NGRAPH_INFO << "There is no input nor initializer named '" << name
<< "' in original model '" << m_model_path << "'.";
}
if (!onnx_initializer)
{
onnx_initializer = onnx_graph->add_initializer();
}
modify_initializer(*onnx_initializer, name, values, onnx_input);
}
}

View File

@@ -53,6 +53,35 @@ namespace ngraph
#endif
}
size_t get_onnx_data_size(int32_t onnx_type)
{
switch (onnx_type)
{
case ONNX_NAMESPACE::TensorProto_DataType_BOOL: return sizeof(char);
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: return 2 * sizeof(double);
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: return 2 * sizeof(float);
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: return sizeof(double);
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: return 2;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: return sizeof(float);
case ONNX_NAMESPACE::TensorProto_DataType_INT8: return sizeof(int8_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT16: return sizeof(int16_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT32: return sizeof(int32_t);
case ONNX_NAMESPACE::TensorProto_DataType_INT64: return sizeof(int64_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: return sizeof(uint8_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: return sizeof(uint16_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return sizeof(uint32_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return sizeof(uint64_t);
}
#ifdef NGRAPH_USE_PROTOBUF_LITE
throw ngraph_error("unsupported element type");
#else
throw ngraph_error(
"unsupported element type: " +
ONNX_NAMESPACE::TensorProto_DataType_Name(
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
#endif
}
std::shared_ptr<ngraph::Node> get_monotonic_range_along_node_rank(
const Output<ngraph::Node>& value, int64_t start_value, int64_t step)
{

View File

@@ -39,6 +39,8 @@ namespace ngraph
{
const ngraph::element::Type& get_ngraph_element_type(std::int64_t onnx_type);
size_t get_onnx_data_size(int32_t onnx_type);
/// \brief Return a monotonic sequence.
///
/// \note Limitations: this function may not work for very large integer values

View File

@@ -0,0 +1,54 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "X"
name: "add_node"
op_type: "Add"
}
name: "test_graph"
input {
name: "A"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "X"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -0,0 +1,28 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "X"
name: "add_node"
op_type: "Add"
}
name: "test_graph"
output {
name: "X"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -0,0 +1,68 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "X"
name: "add_node"
op_type: "Add"
}
name: "test_graph"
initializer {
dims: 2
data_type: 7
int64_data: 1
int64_data: 2
name: "A"
}
initializer {
dims: 2
data_type: 7
int64_data: 1
int64_data: 2
name: "B"
}
input {
name: "A"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "X"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -0,0 +1,42 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "X"
name: "add_node"
op_type: "Add"
}
name: "test_graph"
initializer {
dims: 2
data_type: 7
int64_data: 1
int64_data: 2
name: "A"
}
initializer {
dims: 2
data_type: 7
int64_data: 1
int64_data: 2
name: "B"
}
output {
name: "X"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -22,6 +22,8 @@
#include "ngraph/op/util/op_types.hpp"
#include "onnx_import/editor/editor.hpp"
#include "onnx_import/onnx.hpp"
#include "util/engine/interpreter_engine.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
@@ -284,3 +286,117 @@ NGRAPH_TEST(onnx_editor, shapes__static_to_dynamic_rank_substitution)
EXPECT_TRUE(input->get_partial_shape().same_scheme(new_shape));
}
}
using TestEngine = test::INTERPRETER_Engine;
NGRAPH_TEST(onnx_editor, values__append_one_initializer)
{
onnx_import::ONNXModelEditor editor{
file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_1D.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("A", op::Constant::create(element::i64, Shape{2}, {1, 2}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<int64_t>(Shape{2}, {5, 6});
test_case.add_expected_output<int64_t>(Shape{2}, {6, 8});
test_case.run();
}
NGRAPH_TEST(onnx_editor, values__append_two_initializers_to_invalid)
{
onnx_import::ONNXModelEditor editor{
file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_1D_invalid.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("A", op::Constant::create(element::i64, Shape{2}, {4, 2}));
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {1, 3}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {5, 5});
test_case.run();
}
NGRAPH_TEST(onnx_editor, values__modify_one_initializer)
{
onnx_import::ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/add_1D_with_initializers.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {3, 4}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {4, 6});
test_case.run();
}
NGRAPH_TEST(onnx_editor, values__modify_two_initializers)
{
onnx_import::ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/add_1D_with_initializers.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("A", op::Constant::create(element::i64, Shape{2}, {3, 6}));
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {2, 1}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {5, 7});
test_case.run();
}
NGRAPH_TEST(onnx_editor, values__no_inputs_modify_two_initializers)
{
onnx_import::ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/add_1D_with_initializers_only.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("A", op::Constant::create(element::i64, Shape{2}, {1, 2}));
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {11, 22}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {12, 24});
test_case.run();
}
NGRAPH_TEST(onnx_editor, values__append_two_initializers_change_shape_type)
{
onnx_import::ONNXModelEditor editor{
file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_1D.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("A", op::Constant::create(element::i8, Shape{2, 1}, {-1, 1}));
in_vals.emplace("B", op::Constant::create(element::i8, Shape{2, 1}, {-2, 2}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int8_t>(Shape{2, 1}, {-3, 3});
test_case.run();
}
NGRAPH_TEST(onnx_editor, values__append_two_initializers_mixed_types)
{
onnx_import::ONNXModelEditor editor{
file_util::path_join(SERIALIZED_ZOO, "onnx/gather_elements_float_3D_axis_2.prototxt")};
std::map<std::string, std::shared_ptr<ngraph::op::Constant>> in_vals;
in_vals.emplace("data",
op::Constant::create(element::i16, Shape{2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}));
in_vals.emplace("indices", op::Constant::create(element::i32, Shape{2, 2, 1}, {0, 1, 0, 1}));
editor.set_input_values(in_vals);
const auto function = onnx_import::import_onnx_model(editor);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int16_t>(Shape{2, 2, 1}, {1, 4, 5, 8});
test_case.run();
}