[ONNX FE] Add support for partial shapes with boundaries (#9321)

This commit is contained in:
Tomasz Jankowski 2021-12-21 16:38:15 +01:00 committed by GitHub
parent 6f437fc1bd
commit dcfaf424a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 1 deletions

View File

@ -5,7 +5,7 @@ import os
import onnx
import pytest
from onnx.helper import make_graph, make_model, make_tensor_value_info
from openvino.runtime import PartialShape
from openvino.runtime import Dimension, PartialShape
from openvino.frontend import FrontEndManager
@ -1221,6 +1221,50 @@ def test_set_input_partial_shape_using_input_edge():
assert ov_model.output("out4").get_partial_shape() == PartialShape([10, 10])
def test_set_partial_shape_with_range():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
model = fe.load("input_model.onnx")
input1 = model.get_place_by_tensor_name("in1")
ranged_shape = PartialShape([Dimension(1, 4), Dimension(2)])
model.set_partial_shape(input1, ranged_shape)
ov_model = fe.convert(model)
assert ov_model.input("in1").get_partial_shape() == ranged_shape
def test_set_partial_shape_with_range_and_cut_it_off():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
model = fe.load("input_model.onnx")
input1 = model.get_place_by_tensor_name("in1")
ranged_shape = PartialShape([Dimension(1, 4), Dimension(2)])
model.set_partial_shape(input1, ranged_shape)
add_out = model.get_place_by_tensor_name("add_out")
model.extract_subgraph(inputs=[add_out], outputs=[])
ov_model = fe.convert(model)
for input in ov_model.inputs:
assert input.get_partial_shape() != ranged_shape
def test_set_partial_shape_with_range_and_rename_it():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
model = fe.load("input_model.onnx")
input1 = model.get_place_by_tensor_name("in1")
ranged_shape = PartialShape([Dimension(1, 4), Dimension(2)])
model.set_partial_shape(input1, ranged_shape)
model.set_name_for_tensor(input1, "new_in1")
ov_model = fe.convert(model)
assert ov_model.input("new_in1").get_partial_shape() == ranged_shape
def test_get_partial_shape_using_input_edge():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)

View File

@ -100,6 +100,11 @@ void InputModel::set_name_for_tensor(const ov::frontend::Place::Ptr& tensor, con
m_additional_tensor_names[new_name] = m_additional_tensor_names[original_name];
m_additional_tensor_names.erase(original_name);
}
if (m_inputs_to_reshape.count(original_name) > 0) {
m_inputs_to_reshape[new_name] = m_inputs_to_reshape[original_name];
m_inputs_to_reshape.erase(original_name);
}
}
void InputModel::set_name_for_operation(const ov::frontend::Place::Ptr& operation, const std::string& new_name) {
@ -157,6 +162,9 @@ void InputModel::set_partial_shape(const ov::frontend::Place::Ptr& place, const
}
m_editor->set_input_shapes({{input_name, shape}});
if (shape.get_min_shape() != shape.get_max_shape())
m_inputs_to_reshape[input_name] = shape;
}
ngraph::PartialShape InputModel::get_partial_shape(const ov::frontend::Place::Ptr& place) const {
@ -193,6 +201,7 @@ std::shared_ptr<Model> InputModel::decode() {
std::shared_ptr<Model> InputModel::convert() {
auto converted_model = m_editor->get_function();
add_tensor_names(converted_model);
reshape_model_inputs(converted_model);
return converted_model;
}
@ -296,3 +305,21 @@ void InputModel::add_tensor_names(std::shared_ptr<Model>& model) {
}
}
}
void InputModel::reshape_model_inputs(std::shared_ptr<Model>& model) {
const auto& inputs = model->inputs();
const auto is_input_name = [&inputs](const std::string& name) {
return std::find_if(std::begin(inputs), std::end(inputs), [&name](const OutputVector::value_type& input) {
return input.get_names().count(name) > 0;
}) != std::end(inputs);
};
// assure that names actually refer to model's inputs
std::map<std::string, ov::PartialShape> actual_inputs_to_reshape;
for (const auto& in : m_inputs_to_reshape)
if (is_input_name(in.first))
actual_inputs_to_reshape.insert(in);
if (!actual_inputs_to_reshape.empty())
model->reshape(actual_inputs_to_reshape);
}

View File

@ -69,6 +69,9 @@ private:
std::unordered_map<std::string, std::unordered_set<std::string>> m_additional_tensor_names;
void add_tensor_names(std::shared_ptr<Model>& model);
std::unordered_map<std::string, ov::PartialShape> m_inputs_to_reshape;
void reshape_model_inputs(std::shared_ptr<Model>& model);
};
} // namespace onnx