Provide ONNX external data mechanism to ReadNetwork (#2588)
* added unit test
* added python test
* using pword approach
* Added passing path to onnx reader
* support for wstring
* Added more tests
* Apply suggestions from code review
* fix build for Windows
* styles applied
* Fixed Windows tests
* styles applied
* fixed styles in tests
* review remarks
* cmake order
* Used target_compile_definitions instead of add_definitions
* Move ONNX_TEST_MODELS to other scope

Co-authored-by: Michał Karzyński <4430709+postrational@users.noreply.github.com>
This commit is contained in: parent 9956639531, commit c0d71900fd
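As context for the changes below, a minimal usage sketch of the new mechanism from the Inference Engine API; the model file name is hypothetical, and only the path-based ReadNetwork overload can resolve ONNX external data files:

// Minimal sketch, assuming a model "model.onnx" whose initializers reference
// external files (e.g. "data/tensor.data") relative to the model's directory.
#include <ie_core.hpp>

int main() {
    InferenceEngine::Core ie;
    // The model path is forwarded to the ONNX importer, which uses the model's
    // directory to resolve relative "location" entries of external data tensors.
    InferenceEngine::CNNNetwork network = ie.ReadNetwork("model.onnx", "");
    // ... configure inputs/outputs and load the network to a device as usual.
    return 0;
}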
@@ -57,7 +57,8 @@ public:
     * For IR format (*.bin):
     *  * if path is empty, will try to read bin file with the same name as xml and
     *  * if bin file with the same name was not found, will load IR without weights.
     * ONNX models with data files are not supported
     * For ONNX format (*.onnx or *.prototxt):
     *  * binPath parameter is not used.
     * @return CNNNetwork
     */
    CNNNetwork ReadNetwork(const std::wstring& modelPath, const std::wstring& binPath = {}) const;
@@ -70,7 +71,8 @@ public:
     * For IR format (*.bin):
     *  * if path is empty, will try to read bin file with the same name as xml and
     *  * if bin file with the same name was not found, will load IR without weights.
     * ONNX models with data files are not supported
     * For ONNX format (*.onnx or *.prototxt):
     *  * binPath parameter is not used.
     * @return CNNNetwork
     */
    CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath = {}) const;
@@ -78,7 +80,10 @@ public:
     * @brief Reads models from IR and ONNX formats
     * @param model string with model in IR or ONNX format
     * @param weights shared pointer to constant blob with weights
     * ONNX models don't support models with data blobs.
     * Reading ONNX models doesn't support loading weights from data blobs.
     * If you are using an ONNX model with external data files, please use the
     * `InferenceEngine::Core::ReadNetwork(const std::string& modelPath, const std::string& binPath) const`
     * function overload which takes a filesystem path to the model.
     * For the ONNX case the second parameter should contain an empty blob.
     * @return CNNNetwork
     */
@@ -168,6 +168,10 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string&
#endif
    // Try to open model file
    std::ifstream modelStream(model_path, std::ios::binary);
    // save path in extensible array of stream
    // notice: lifetime of path pointed by pword(0) is limited by current scope
    const std::string path_to_save_in_stream = modelPath;
    modelStream.pword(0) = const_cast<char*>(path_to_save_in_stream.c_str());
    if (!modelStream.is_open())
        THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
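The path handoff above uses only standard iostream facilities; below is a self-contained sketch of the same pword(0) write/read pattern (stream contents and path are made up), mirroring what the ONNX reader further down does on the read side:

// Standalone sketch of the std::ios_base::pword() pattern: a raw pointer is
// stashed in the stream's extensible array at index 0 and read back later.
// As in the real code, the pointed-to string must outlive every reader of pword(0).
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>

int main() {
    std::istringstream stream("fake model bytes");    // stands in for the model stream
    const std::string path = "some/dir/model.onnx";   // hypothetical model path

    // Writer side (what ReadNetwork does): store a pointer to the path.
    stream.pword(0) = const_cast<char*>(path.c_str());

    // Reader side (what the ONNX reader does): recover the path, or an empty
    // string when nothing was stored.
    std::string recovered =
        stream.pword(0) ? std::string{static_cast<char*>(stream.pword(0))} : std::string{};

    assert(recovered == path);
    std::cout << "recovered: " << recovered << std::endl;
    return 0;
}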
@@ -26,6 +26,9 @@ CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath,
 * @param model string with IR
 * @param weights shared pointer to constant blob with weights
 * @param exts vector with extensions
 * @note Reading ONNX models doesn't support loading weights from data blobs.
 *       If you are using an ONNX model with external data files, please use the
 *       ReadNetwork function overload which takes a filesystem path to the model.
 * @return CNNNetwork
 */
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts);
@@ -21,8 +21,18 @@ bool ONNXReader::supportModel(std::istream& model) const {
    return !((header.find("<net ") != std::string::npos) || (header.find("<Net ") != std::string::npos));
}

namespace {
    std::string readPathFromStream(std::istream& stream) {
        if (stream.pword(0) == nullptr) {
            return {};
        }
        // read saved path from extensible array
        return std::string{static_cast<char*>(stream.pword(0))};
    }
}

CNNNetwork ONNXReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
    return CNNNetwork(ngraph::onnx_import::import_onnx_model(model));
    return CNNNetwork(ngraph::onnx_import::import_onnx_model(model, readPathFromStream(model)));
}

INFERENCE_PLUGIN_API(StatusCode) InferenceEngine::CreateReader(IReader*& reader, ResponseDesc *resp) noexcept {
@@ -52,6 +52,8 @@ if(TARGET inference_engine_onnx_reader)
    add_dependencies(${TARGET_NAME} inference_engine_onnx_reader)
endif()

target_compile_definitions(${TARGET_NAME} PRIVATE ONNX_TEST_MODELS="${CMAKE_CURRENT_SOURCE_DIR}/onnx_reader/models/")

include(CMakeParseArguments)

#
Binary file not shown.
@@ -0,0 +1,97 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "A"
    input: "B"
    output: "X"
    name: "add_node1"
    op_type: "Add"
  }
  node {
    input: "X"
    input: "C"
    output: "Y"
    name: "add_node2"
    op_type: "Add"
  }
  name: "test_graph"
  initializer {
    dims: 2
    dims: 2
    data_type: 1
    name: "A"
    external_data {
      key: "location",
      value: "data/tensor.data"
    }
    data_location: 1
  }
  input {
    name: "A"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  input {
    name: "B"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  input {
    name: "C"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "Y"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
}
opset_import {
  version: 4
}
@@ -0,0 +1,97 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "A"
    input: "B"
    output: "X"
    name: "multiply_node_1"
    op_type: "Mul"
  }
  node {
    input: "X"
    input: "C"
    output: "Y"
    name: "multiply_node_2"
    op_type: "Mul"
  }
  name: "test_graph"
  initializer {
    dims: 2
    dims: 2
    data_type: 1
    name: "A"
    external_data {
      key: "location",
      value: "../data/tensor.data"
    }
    data_location: 1
  }
  input {
    name: "A"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  input {
    name: "B"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  input {
    name: "C"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "Y"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
}
opset_import {
  version: 4
}
@@ -0,0 +1,112 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <set>
#include <string>
#include <fstream>
#include <algorithm>

#include <ie_blob.h>
#include <ie_core.hpp>
#include <file_utils.h>
#include <streambuf>
#include <ngraph/ngraph.hpp>

TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromFile) {
    InferenceEngine::Core ie;
    auto cnnNetwork = ie.ReadNetwork(std::string(ONNX_TEST_MODELS) + "onnx_external_data.prototxt", "");
    auto function = cnnNetwork.getFunction();

    int count_additions = 0;
    int count_constants = 0;
    int count_parameters = 0;

    std::shared_ptr<ngraph::Node> external_data_node;
    for (auto op : function->get_ops()) {
        const auto op_type = std::string(op->get_type_name());
        count_additions += (op_type == "Add" ? 1 : 0);
        count_parameters += (op_type == "Parameter" ? 1 : 0);
        if (op_type == "Constant") {
            count_constants += 1;
            external_data_node = op;
        }
    }

    ASSERT_EQ(function->get_output_size(), 1);
    ASSERT_EQ(std::string(function->get_output_op(0)->get_type_name()), "Result");
    ASSERT_EQ(function->get_output_element_type(0), ngraph::element::f32);
    ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({2, 2}));
    ASSERT_EQ(count_additions, 2);
    ASSERT_EQ(count_constants, 1);
    ASSERT_EQ(count_parameters, 2);

    const auto external_data_node_const = ngraph::as_type_ptr<ngraph::op::Constant>(external_data_node);
    ASSERT_TRUE(external_data_node_const->get_vector<float>() == (std::vector<float>{1, 2, 3, 4}));
}

TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromStringException) {
    InferenceEngine::Core ie;
    const auto path = std::string(ONNX_TEST_MODELS) + "onnx_external_data.prototxt";
    InferenceEngine::Blob::CPtr weights;  // not used
    std::ifstream stream(path, std::ios::binary);
    std::string modelAsString((std::istreambuf_iterator<char>(stream)), std::istreambuf_iterator<char>());
    stream.close();
    try {
        auto cnnNetwork = ie.ReadNetwork(modelAsString, weights);
    }
    catch(const ngraph::ngraph_error& e) {
        EXPECT_PRED_FORMAT2(
            testing::IsSubstring,
            std::string("invalid external data:"),
            e.what());

        EXPECT_PRED_FORMAT2(
            testing::IsSubstring,
            std::string("data/tensor.data, offset: 0, data_lenght: 0, sha1_digest: 0)"),
            e.what());
    }
    catch(...) {
        FAIL() << "Reading network failed for unexpected reason";
    }
}

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromWstringNamedFile) {
    InferenceEngine::Core ie;
    std::string win_dir_path = ONNX_TEST_MODELS;
    std::replace(win_dir_path.begin(), win_dir_path.end(), '/', '\\');
    const std::wstring unicode_win_dir_path = FileUtils::multiByteCharToWString(win_dir_path.c_str());
    const std::wstring path = unicode_win_dir_path + L"АБВГДЕЁЖЗИЙ\\ひらがな日本語.prototxt";

    auto cnnNetwork = ie.ReadNetwork(path, L"");
    auto function = cnnNetwork.getFunction();

    int count_multiply = 0;
    int count_constants = 0;
    int count_parameters = 0;

    std::shared_ptr<ngraph::Node> external_data_node;
    for (auto op : function->get_ops()) {
        const auto op_type = std::string(op->get_type_name());
        count_multiply += (op_type == "Multiply" ? 1 : 0);
        count_parameters += (op_type == "Parameter" ? 1 : 0);
        if (op_type == "Constant") {
            count_constants += 1;
            external_data_node = op;
        }
    }

    ASSERT_EQ(function->get_output_size(), 1);
    ASSERT_EQ(std::string(function->get_output_op(0)->get_type_name()), "Result");
    ASSERT_EQ(function->get_output_element_type(0), ngraph::element::f32);
    ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({2, 2}));
    ASSERT_EQ(count_multiply, 2);
    ASSERT_EQ(count_constants, 1);
    ASSERT_EQ(count_parameters, 2);

    const auto external_data_node_const = ngraph::as_type_ptr<ngraph::op::Constant>(external_data_node);
    ASSERT_TRUE(external_data_node_const->get_vector<float>() == (std::vector<float>{1, 2, 3, 4}));
}
#endif
@@ -63,5 +63,19 @@ namespace ngraph
            std::function<void(const std::string& file, bool is_dir)> func,
            bool recurse = false,
            bool include_links = false);

        /// \brief Change Linux-style path ('/') to Windows-style ('\\')
        /// \param path The path to change file separator
        NGRAPH_API void convert_path_win_style(std::string& path);

        /// \brief Conversion from wide character string to a single-byte chain.
        /// \param wstr A wide-char string
        /// \return A multi-byte string
        NGRAPH_API std::string wstring_to_string(const std::wstring& wstr);

        /// \brief Conversion from single-byte chain to wide character string.
        /// \param str A null-terminated string
        /// \return A wide-char string
        NGRAPH_API std::wstring multi_byte_char_to_wstring(const char* str);
    }
}
@@ -30,3 +30,14 @@
#else
#define NGRAPH_API NGRAPH_HELPER_DLL_IMPORT
#endif // ngraph_EXPORTS

#ifndef ENABLE_UNICODE_PATH_SUPPORT
#ifdef _WIN32
#if defined __INTEL_COMPILER || defined _MSC_VER
#define ENABLE_UNICODE_PATH_SUPPORT
#endif
#elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || \
    defined(__clang__)
#define ENABLE_UNICODE_PATH_SUPPORT
#endif
#endif
@@ -23,6 +23,7 @@
#include <sys/time.h>
#include <unistd.h>
#endif
#include <algorithm>
#include <fcntl.h>
#include <fstream>
#include <iostream>
@@ -43,6 +44,10 @@
#else
#define RMDIR(a) rmdir(a)
#define RMFILE(a) remove(a)
#ifdef ENABLE_UNICODE_PATH_SUPPORT
#include <codecvt>
#include <locale>
#endif
#endif

using namespace std;
@@ -77,10 +82,19 @@ string file_util::get_file_ext(const string& s)
string file_util::get_directory(const string& s)
{
    string rc = s;
    // Linux-style separator
    auto pos = s.find_last_of('/');
    if (pos != string::npos)
    {
        rc = s.substr(0, pos);
        return rc;
    }
    // Windows-style separator
    pos = s.find_last_of('\\');
    if (pos != string::npos)
    {
        rc = s.substr(0, pos);
        return rc;
    }
    return rc;
}
@@ -240,3 +254,42 @@ void file_util::iterate_files(const string& path,
        func(f, true);
    }
}

NGRAPH_API void file_util::convert_path_win_style(std::string& path)
{
    std::replace(path.begin(), path.end(), '/', '\\');
}

#ifdef ENABLE_UNICODE_PATH_SUPPORT

std::string file_util::wstring_to_string(const std::wstring& wstr)
{
#ifdef _WIN32
    int size_needed =
        WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); // NOLINT
    std::string strTo(size_needed, 0);
    WideCharToMultiByte(
        CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); // NOLINT
    return strTo;
#else
    std::wstring_convert<std::codecvt_utf8<wchar_t>> wstring_decoder;
    return wstring_decoder.to_bytes(wstr);
#endif
}

std::wstring file_util::multi_byte_char_to_wstring(const char* str)
{
#ifdef _WIN32
    int strSize = static_cast<int>(std::strlen(str));
    int size_needed = MultiByteToWideChar(CP_UTF8, 0, str, strSize, NULL, 0);
    std::wstring wstrTo(size_needed, 0);
    MultiByteToWideChar(CP_UTF8, 0, str, strSize, &wstrTo[0], size_needed);
    return wstrTo;
#else
    std::wstring_convert<std::codecvt_utf8<wchar_t>> wstring_encoder;
    std::wstring result = wstring_encoder.from_bytes(str);
    return result;
#endif
}

#endif // ENABLE_UNICODE_PATH_SUPPORT
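A short round-trip sketch of the two new helpers (the path below is hypothetical, and the conversions are only compiled when ENABLE_UNICODE_PATH_SUPPORT is defined):

// Round-trip sketch for wstring_to_string() / multi_byte_char_to_wstring().
#include <cassert>
#include <string>
#include "ngraph/file_util.hpp"

int main() {
#ifdef ENABLE_UNICODE_PATH_SUPPORT
    const std::string utf8_path = "models/ひらがな日本語.prototxt";  // hypothetical UTF-8 path
    const std::wstring wide_path = ngraph::file_util::multi_byte_char_to_wstring(utf8_path.c_str());
    const std::string round_tripped = ngraph::file_util::wstring_to_string(wide_path);
    assert(round_tripped == utf8_path);
#endif
    return 0;
}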
@@ -33,7 +33,8 @@ namespace ngraph
            /// \brief Load external data from tensor passed to constructor
            ///
            /// \note If read data from external file fails,
            ///       the invalid_external_data is thrown
            /// \note If reading data from external files fails,
            ///       the invalid_external_data exception is thrown.
            ///
            /// \return External binary data loaded into a std::string
            std::string load_external_data() const;
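A hedged sketch of how this failure mode surfaces to a caller; the model file name is made up, and the onnx_import header name is assumed (the two-argument import_onnx_model overload and the ngraph_error base are the ones exercised by the tests above):

// Sketch only: when external data cannot be resolved (here, no model directory
// is supplied), load_external_data() throws invalid_external_data, which callers
// can observe as an ngraph::ngraph_error whose message starts with
// "invalid external data:".
#include <fstream>
#include <iostream>
#include <ngraph/ngraph.hpp>
#include "onnx_import/onnx.hpp"  // assumed header exposing import_onnx_model()

int main() {
    std::ifstream model("model_with_external_data.onnx", std::ios::binary);  // hypothetical model
    try {
        auto function = ngraph::onnx_import::import_onnx_model(model, "");
    } catch (const ngraph::ngraph_error& error) {
        std::cerr << error.what() << std::endl;
    }
    return 0;
}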
@@ -119,9 +119,13 @@ namespace ngraph
                {
                    const auto external_data_relative_path =
                        initializer_tensor.external_data(location_key_value_index).value();
                    const auto external_data_full_path =
                    auto external_data_full_path =
                        file_util::path_join(model_dir_path, external_data_relative_path);

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
                    file_util::convert_path_win_style(external_data_full_path);
#endif

                    // Set full paths to the external file
                    initializer_tensor.mutable_external_data(location_key_value_index)
                        ->set_value(external_data_full_path);
@@ -17,6 +17,7 @@
#include <fstream>
#include <sstream>

#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
#include "onnx_import/exceptions.hpp"
#include "tensor_external_data.hpp"
@@ -44,7 +45,12 @@ namespace ngraph

            std::string TensorExternalData::load_external_data() const
            {
                std::ifstream external_data_stream(m_data_location,
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
                std::wstring path = file_util::multi_byte_char_to_wstring(m_data_location.c_str());
#else
                std::string path = m_data_location;
#endif
                std::ifstream external_data_stream(path,
                                                   std::ios::binary | std::ios::in | std::ios::ate);
                if (external_data_stream.fail())
                    throw error::invalid_external_data{*this};
BIN  ngraph/python/tests/test_onnx/models/data/tensor.data  (new binary file, not shown)
 77  ngraph/python/tests/test_onnx/models/external_data.prototxt  (new file)
@@ -0,0 +1,77 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "data_a"
    input: "data_b"
    input: "data_c"
    output: "result"
    op_type: "Mean"
  }
  name: "test_mean_example"
  initializer {
    dims: 3
    data_type: 1
    name: "data_c"
    external_data {
      key: "location",
      value: "data/tensor.data"
    }
    data_location: 1
  }
  input {
    name: "data_a"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  input {
    name: "data_b"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  input {
    name: "data_c"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  output {
    name: "result"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
}
opset_import {
  version: 8
}
 41  ngraph/python/tests/test_onnx/test_onnx_external_data.py  (new file)
@@ -0,0 +1,41 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

import os

import numpy as np
import ngraph as ng
from openvino.inference_engine import IECore

from tests.runtime import get_runtime


def test_import_onnx_with_external_data():
    model_path = os.path.join(os.path.dirname(__file__), "models/external_data.prototxt")
    ie = IECore()
    ie_network = ie.read_network(model=model_path)

    ng_function = ng.function_from_cnn(ie_network)

    dtype = np.float32
    value_a = np.array([1.0, 3.0, 5.0], dtype=dtype)
    value_b = np.array([3.0, 5.0, 1.0], dtype=dtype)
    # third input [5.0, 1.0, 3.0] read from external file

    runtime = get_runtime()
    computation = runtime.computation(ng_function)
    result = computation(value_a, value_b)
    assert np.allclose(result, np.array([3.0, 3.0, 3.0], dtype=dtype))