[ONNX] Fix external weights loading for the current dir path case (#16124)

This commit is contained in:
Mateusz Bencer
2023-03-08 12:12:23 +01:00
committed by GitHub
parent 0786a963ab
commit 50b76873e2
2 changed files with 29 additions and 5 deletions

View File

@@ -6,12 +6,31 @@ import os
import numpy as np
from openvino.runtime import Core
import shutil
import pytest
from tests.runtime import get_runtime
def test_import_onnx_with_external_data():
model_path = os.path.join(os.path.dirname(__file__), "models/external_data.onnx")
external_data_model_current_folder_path = "external_data.onnx"
external_data_model_full_path = os.path.join(os.path.dirname(__file__), "models", external_data_model_current_folder_path)
external_data_current_folder_path = "data/tensor.data"
external_data_full_path = os.path.join(os.path.dirname(__file__), "models", external_data_current_folder_path)
def setup_module():
shutil.copyfile(external_data_model_full_path, external_data_model_current_folder_path)
os.mkdir("data")
shutil.copyfile(external_data_full_path, external_data_current_folder_path)
def teardown_module():
os.remove(external_data_model_current_folder_path)
os.remove(external_data_current_folder_path)
os.rmdir("data")
@pytest.mark.parametrize("model_path", [external_data_model_full_path, external_data_model_current_folder_path])
def test_import_onnx_with_external_data(model_path: str):
core = Core()
model = core.read_model(model=model_path)

View File

@@ -12,6 +12,7 @@
#include "ngraph/file_util.hpp"
#include "onnx_framework_node.hpp"
#include "onnx_import/core/null_node.hpp"
#include "openvino/util/file_util.hpp"
namespace ngraph {
namespace onnx_import {
@@ -91,7 +92,9 @@ std::shared_ptr<Function> import_onnx_model(std::shared_ptr<ONNX_NAMESPACE::Mode
ov::frontend::ExtensionHolder extensions) {
apply_transformations(*model_proto);
NGRAPH_SUPPRESS_DEPRECATED_START
Graph graph{file_util::get_directory(model_path), model_proto, std::move(extensions)};
Graph graph{file_util::get_directory(ov::util::get_absolute_file_path(model_path)),
model_proto,
std::move(extensions)};
NGRAPH_SUPPRESS_DEPRECATED_END
return graph.convert();
}
@@ -101,7 +104,9 @@ std::shared_ptr<Function> decode_to_framework_nodes(std::shared_ptr<ONNX_NAMESPA
ov::frontend::ExtensionHolder extensions) {
apply_transformations(*model_proto);
NGRAPH_SUPPRESS_DEPRECATED_START
auto graph = std::make_shared<Graph>(file_util::get_directory(model_path), model_proto, extensions);
auto graph = std::make_shared<Graph>(file_util::get_directory(ov::util::get_absolute_file_path(model_path)),
model_proto,
extensions);
NGRAPH_SUPPRESS_DEPRECATED_END
return graph->decode();
}