Add ONNX model metadata to ov::Model (#13712)

* Add ONNX model metadata to ov::Model

* Add correct path to onnx models
This commit is contained in:
Artur Kulikowski 2022-10-31 08:15:26 +01:00 committed by GitHub
parent af5e06bb8b
commit 0aeb3d8151
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 175 additions and 1 deletions

View File

@ -262,10 +262,21 @@ void Graph::remove_dangling_parameters() {
}
}
void Graph::set_metadata(std::shared_ptr<ov::Model>& model) const {
const std::string framework_section = "framework";
const auto metadata = m_model->get_metadata();
for (const auto& pair : metadata) {
model->set_rt_info(pair.second, framework_section, pair.first);
}
}
std::shared_ptr<Function> Graph::convert() {
convert_to_ngraph_nodes();
remove_dangling_parameters();
return create_function();
auto function = create_function();
set_metadata(function);
return function;
}
OutputVector Graph::make_framework_nodes(const Node& onnx_node) {

View File

@ -66,6 +66,7 @@ protected:
void decode_to_framework_nodes();
void convert_to_ngraph_nodes();
void remove_dangling_parameters();
void set_metadata(std::shared_ptr<ov::Model>& model) const;
std::shared_ptr<Function> create_function();
ParameterVector m_parameters;

View File

@ -54,6 +54,16 @@ public:
return m_model_proto->producer_version();
}
std::map<std::string, std::string> get_metadata() const {
std::map<std::string, std::string> metadata;
const auto& model_metadata = m_model_proto->metadata_props();
for (const auto& prop : model_metadata) {
metadata.emplace(prop.key(), prop.value());
}
return metadata;
}
/// \brief Access an operator object by its type name and domain name
/// The function will return the operator object if it exists, or report an error
/// in case of domain or operator absence.

View File

@ -0,0 +1,130 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "out"
op_type: "PriorBoxClustered"
attribute {
name: "width"
floats: 0.10000000149011612
floats: 0.10000000149011612
floats: 0.20000000298023224
floats: 0.20000000298023224
type: FLOATS
}
attribute {
name: "height"
floats: 0.10000000149011612
floats: 0.10000000149011612
floats: 0.20000000298023224
floats: 0.20000000298023224
type: FLOATS
}
attribute {
name: "step_w"
f: 64
type: FLOAT
}
attribute {
name: "clip"
i: 1
type: INT
}
attribute {
name: "step_h"
f: 64
type: FLOAT
}
attribute {
name: "offset"
f: 0.5
type: FLOAT
}
attribute {
name: "variance"
floats: 0.10000000149011612
floats: 0.10000000149011612
floats: 0.20000000298023224
floats: 0.20000000298023224
type: FLOATS
}
domain: "org.openvinotoolkit"
}
name: "compute_graph"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "out"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 1
}
metadata_props {
key: "meta_key1"
value: "meta_value1"
}
metadata_props {
key: "meta_key2"
value: "meta_value2"
}

View File

@ -10,6 +10,7 @@
#include "common_test_utils/file_utils.hpp"
#include "onnx_import/onnx.hpp"
#include "openvino/openvino.hpp"
#include "openvino/util/file_util.hpp"
TEST(ONNX_Importer_Tests, ImportBasicModel) {
@ -159,3 +160,24 @@ TEST(ONNX_Importer_Tests, IsOperatorSupported) {
ASSERT_TRUE(is_abs_op_supported);
}
TEST(ONNX_Importer_Tests, ImportModelWithoutMetadata) {
ov::Core core;
auto model = core.read_model(
CommonTestUtils::getModelFromTestModelZoo(ov::util::path_join({ONNX_MODELS_DIR, "priorbox_clustered.onnx"})));
ASSERT_FALSE(model->has_rt_info("framework"));
}
TEST(ONNX_Importer_Tests, ImportModelWithMetadata) {
ov::Core core;
auto model = core.read_model(
CommonTestUtils::getModelFromTestModelZoo(ov::util::path_join({ONNX_MODELS_DIR, "model_with_metadata.onnx"})));
ASSERT_TRUE(model->has_rt_info("framework"));
const auto rtinfo = model->get_rt_info();
auto metadata = rtinfo.at("framework").as<ov::AnyMap>();
ASSERT_EQ(metadata.size(), 2);
ASSERT_EQ(metadata["meta_key1"].as<std::string>(), "meta_value1");
ASSERT_EQ(metadata["meta_key2"].as<std::string>(), "meta_value2");
}