Provide ONNX external data mechanism to ReadNetwork (#2588)
* added unit test * added python test * using pword approach * Added passing path to onnx reader * support for wstring * Added more tests * Apply suggestions from code review Co-authored-by: Michał Karzyński <4430709+postrational@users.noreply.github.com> * fix build for Windows * styles applied * Fixed Windows tests * styles applied * fixed styles in tests * review remarks * cmake order * Used target_compile_definitions instead of add_definitions * Move ONNX_TEST_MODELS to other scope Co-authored-by: Michał Karzyński <4430709+postrational@users.noreply.github.com>
This commit is contained in:
@@ -52,6 +52,8 @@ if(TARGET inference_engine_onnx_reader)
|
||||
add_dependencies(${TARGET_NAME} inference_engine_onnx_reader)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(${TARGET_NAME} PRIVATE ONNX_TEST_MODELS="${CMAKE_CURRENT_SOURCE_DIR}/onnx_reader/models/")
|
||||
|
||||
include(CMakeParseArguments)
|
||||
|
||||
#
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,97 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "A"
|
||||
input: "B"
|
||||
output: "X"
|
||||
name: "add_node1"
|
||||
op_type: "Add"
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "C"
|
||||
output: "Y"
|
||||
name: "add_node2"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "test_graph"
|
||||
initializer {
|
||||
dims: 2
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "A"
|
||||
external_data {
|
||||
key: "location",
|
||||
value: "data/tensor.data"
|
||||
}
|
||||
data_location: 1
|
||||
}
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "B"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "C"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 4
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "A"
|
||||
input: "B"
|
||||
output: "X"
|
||||
name: "multiply_node_1"
|
||||
op_type: "Mul"
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "C"
|
||||
output: "Y"
|
||||
name: "multiply_node_2"
|
||||
op_type: "Mul"
|
||||
}
|
||||
name: "test_graph"
|
||||
initializer {
|
||||
dims: 2
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "A"
|
||||
external_data {
|
||||
key: "location",
|
||||
value: "../data/tensor.data"
|
||||
}
|
||||
data_location: 1
|
||||
}
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "B"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "C"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 4
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
|
||||
#include <ie_blob.h>
|
||||
#include <ie_core.hpp>
|
||||
#include <file_utils.h>
|
||||
#include <streambuf>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
|
||||
TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromFile) {
|
||||
InferenceEngine::Core ie;
|
||||
auto cnnNetwork = ie.ReadNetwork(std::string(ONNX_TEST_MODELS) + "onnx_external_data.prototxt", "");
|
||||
auto function = cnnNetwork.getFunction();
|
||||
|
||||
int count_additions = 0;
|
||||
int count_constants = 0;
|
||||
int count_parameters = 0;
|
||||
|
||||
std::shared_ptr<ngraph::Node> external_data_node;
|
||||
for (auto op : function->get_ops()) {
|
||||
const auto op_type = std::string(op->get_type_name());
|
||||
count_additions += (op_type == "Add" ? 1 : 0);
|
||||
count_parameters += (op_type == "Parameter" ? 1 : 0);
|
||||
if (op_type == "Constant") {
|
||||
count_constants += 1;
|
||||
external_data_node = op;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(function->get_output_size(), 1);
|
||||
ASSERT_EQ(std::string(function->get_output_op(0)->get_type_name()), "Result");
|
||||
ASSERT_EQ(function->get_output_element_type(0), ngraph::element::f32);
|
||||
ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({2, 2}));
|
||||
ASSERT_EQ(count_additions, 2);
|
||||
ASSERT_EQ(count_constants, 1);
|
||||
ASSERT_EQ(count_parameters, 2);
|
||||
|
||||
const auto external_data_node_const = ngraph::as_type_ptr<ngraph::op::Constant>(external_data_node);
|
||||
ASSERT_TRUE(external_data_node_const->get_vector<float>() == (std::vector<float>{1, 2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromStringException) {
|
||||
InferenceEngine::Core ie;
|
||||
const auto path = std::string(ONNX_TEST_MODELS) + "onnx_external_data.prototxt";
|
||||
InferenceEngine::Blob::CPtr weights; //not used
|
||||
std::ifstream stream(path, std::ios::binary);
|
||||
std::string modelAsString((std::istreambuf_iterator<char>(stream)), std::istreambuf_iterator<char>());
|
||||
stream.close();
|
||||
try {
|
||||
auto cnnNetwork = ie.ReadNetwork(modelAsString, weights);
|
||||
}
|
||||
catch(const ngraph::ngraph_error& e) {
|
||||
EXPECT_PRED_FORMAT2(
|
||||
testing::IsSubstring,
|
||||
std::string("invalid external data:"),
|
||||
e.what());
|
||||
|
||||
EXPECT_PRED_FORMAT2(
|
||||
testing::IsSubstring,
|
||||
std::string("data/tensor.data, offset: 0, data_lenght: 0, sha1_digest: 0)"),
|
||||
e.what());
|
||||
}
|
||||
catch(...) {
|
||||
FAIL() << "Reading network failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromWstringNamedFile) {
|
||||
InferenceEngine::Core ie;
|
||||
std::string win_dir_path = ONNX_TEST_MODELS;
|
||||
std::replace(win_dir_path.begin(), win_dir_path.end(), '/', '\\');
|
||||
const std::wstring unicode_win_dir_path = FileUtils::multiByteCharToWString(win_dir_path.c_str());
|
||||
const std::wstring path = unicode_win_dir_path + L"АБВГДЕЁЖЗИЙ\\ひらがな日本語.prototxt";
|
||||
|
||||
auto cnnNetwork = ie.ReadNetwork(path, L"");
|
||||
auto function = cnnNetwork.getFunction();
|
||||
|
||||
int count_multiply = 0;
|
||||
int count_constants = 0;
|
||||
int count_parameters = 0;
|
||||
|
||||
std::shared_ptr<ngraph::Node> external_data_node;
|
||||
for (auto op : function->get_ops()) {
|
||||
const auto op_type = std::string(op->get_type_name());
|
||||
count_multiply += (op_type == "Multiply" ? 1 : 0);
|
||||
count_parameters += (op_type == "Parameter" ? 1 : 0);
|
||||
if (op_type == "Constant") {
|
||||
count_constants += 1;
|
||||
external_data_node = op;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(function->get_output_size(), 1);
|
||||
ASSERT_EQ(std::string(function->get_output_op(0)->get_type_name()), "Result");
|
||||
ASSERT_EQ(function->get_output_element_type(0), ngraph::element::f32);
|
||||
ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({2, 2}));
|
||||
ASSERT_EQ(count_multiply, 2);
|
||||
ASSERT_EQ(count_constants, 1);
|
||||
ASSERT_EQ(count_parameters, 2);
|
||||
|
||||
const auto external_data_node_const = ngraph::as_type_ptr<ngraph::op::Constant>(external_data_node);
|
||||
ASSERT_TRUE(external_data_node_const->get_vector<float>() == (std::vector<float>{1, 2, 3, 4}));
|
||||
}
|
||||
#endif
|
||||
Reference in New Issue
Block a user