Fix "Unexpected number of outputs after override_all_outputs" (#9454)
This commit is contained in:
parent
f255c195c5
commit
e89db1c6de
@ -49,6 +49,16 @@ from openvino.frontend import FrontEndManager
|
||||
# | |
|
||||
# out1 out2
|
||||
#
|
||||
#
|
||||
# ------Test input model 3------
|
||||
# in1 in2
|
||||
# | / \
|
||||
# +--------+ +------+
|
||||
# | Add | | Relu |
|
||||
# +--------+ +------+
|
||||
# | |
|
||||
# out1 out2
|
||||
#
|
||||
def create_test_onnx_models():
|
||||
models = {}
|
||||
# Input model 1
|
||||
@ -91,6 +101,23 @@ def create_test_onnx_models():
|
||||
models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Input model 3
|
||||
add_2 = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["out1"], name="onnx_add_op")
|
||||
relu_2 = onnx.helper.make_node("Relu", inputs=["in2"], outputs=["out2"])
|
||||
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add_2, relu_2], "test_graph_3", input_tensors, output_tensors)
|
||||
models["input_model_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
@ -188,6 +215,19 @@ def create_test_onnx_models():
|
||||
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for test_override_all_outputs 3
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add_2], "test_graph_3", input_tensors, output_tensors)
|
||||
models["test_override_all_outputs_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for test_override_all_inputs
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
@ -594,6 +634,50 @@ def test_override_all_outputs_2():
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_outputs_3():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model_3.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensor_name="out1")
|
||||
place2 = model.get_place_by_tensor_name(tensor_name="out1")
|
||||
model.override_all_outputs(outputs=[place1, place2])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_override_all_outputs_3.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_outputs_invalid_place():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model_3.onnx")
|
||||
assert model
|
||||
|
||||
model2 = fe.load("input_model.onnx")
|
||||
assert model2
|
||||
invalid_place = model2.get_place_by_tensor_name(tensor_name="out3")
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensor_name="out1")
|
||||
place2 = model.get_place_by_tensor_name(tensor_name="out1")
|
||||
model.override_all_outputs(outputs=[place1, place2, invalid_place])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_override_all_outputs_3.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_inputs():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
@ -618,26 +702,31 @@ def test_override_all_inputs():
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_inputs_exceptions():
|
||||
def test_override_all_inputs_invalid_place():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
model = fe.load("input_model_3.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensor_name="in1")
|
||||
place2 = model.get_place_by_tensor_name(tensor_name="in2")
|
||||
place3 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
|
||||
place4 = model.get_place_by_tensor_name(tensor_name="in3")
|
||||
model2 = fe.load("input_model.onnx")
|
||||
assert model2
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
model.override_all_inputs(inputs=[place1, place2])
|
||||
assert "Unexpected number of inputs after override_all_inputs" in str(e)
|
||||
out3_tensor = model2.get_place_by_tensor_name(tensor_name="out3")
|
||||
invalid_place = out3_tensor.get_producing_operation().get_input_port(input_port_index=0)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
model.override_all_inputs(inputs=[place3, place4])
|
||||
assert "Unexpected number of inputs after override_all_inputs" in str(e)
|
||||
out1_tensor = model.get_place_by_tensor_name(tensor_name="out1")
|
||||
place1 = out1_tensor.get_producing_operation().get_input_port(input_port_index=0)
|
||||
place2 = out1_tensor.get_producing_operation().get_input_port(input_port_index=1)
|
||||
model.override_all_inputs(inputs=[place1, place2, invalid_place])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("input_model_3.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_is_input_output():
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <openvino/frontend/exception.hpp>
|
||||
#include <openvino/util/file_util.hpp>
|
||||
|
||||
#include "ngraph/log.hpp"
|
||||
#include "place.hpp"
|
||||
|
||||
using namespace ov;
|
||||
@ -202,28 +203,90 @@ std::shared_ptr<Model> InputModel::convert() {
|
||||
}
|
||||
|
||||
// Editor features
|
||||
bool InputModel::is_correct_place(const ov::frontend::Place::Ptr& place) const {
|
||||
if (const auto tensor = std::dynamic_pointer_cast<PlaceTensor>(place)) {
|
||||
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
|
||||
}
|
||||
if (const auto op = std::dynamic_pointer_cast<PlaceOp>(place)) {
|
||||
return m_editor->is_correct_and_unambiguous_node(op->get_editor_node());
|
||||
}
|
||||
if (const auto input_edge = std::dynamic_pointer_cast<PlaceInputEdge>(place)) {
|
||||
if (auto tensor = std::dynamic_pointer_cast<PlaceTensor>(input_edge->get_source_tensor())) {
|
||||
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
|
||||
}
|
||||
}
|
||||
if (const auto output_edge = std::dynamic_pointer_cast<PlaceOutputEdge>(place)) {
|
||||
if (auto tensor = std::dynamic_pointer_cast<PlaceTensor>(output_edge->get_target_tensor())) {
|
||||
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void InputModel::override_all_outputs(const std::vector<ov::frontend::Place::Ptr>& outputs) {
|
||||
extract_subgraph({}, outputs);
|
||||
NGRAPH_CHECK(m_editor->model_outputs().size() == outputs.size(),
|
||||
"Unexpected number of outputs after override_all_outputs");
|
||||
NGRAPH_CHECK(std::all_of(std::begin(outputs),
|
||||
std::end(outputs),
|
||||
std::vector<Place::Ptr> expected_valid_outputs;
|
||||
for (const auto& output : outputs) {
|
||||
bool is_correct = is_correct_place(output);
|
||||
if (!is_correct)
|
||||
NGRAPH_WARN << "Name " << output->get_names().at(0)
|
||||
<< " of output node is not a correct node name. Ignoring this parameter.";
|
||||
else
|
||||
expected_valid_outputs.push_back(output);
|
||||
}
|
||||
|
||||
extract_subgraph({}, expected_valid_outputs);
|
||||
|
||||
NGRAPH_CHECK(std::all_of(std::begin(expected_valid_outputs),
|
||||
std::end(expected_valid_outputs),
|
||||
[](const ov::frontend::Place::Ptr& place) {
|
||||
return place->is_output();
|
||||
}),
|
||||
"Not all provided arguments of override_all_outputs are new outputs of the model");
|
||||
|
||||
const auto current_outputs = get_outputs();
|
||||
NGRAPH_CHECK(std::all_of(std::begin(current_outputs),
|
||||
std::end(current_outputs),
|
||||
[&](const Place::Ptr& current_out) {
|
||||
return std::find_if(std::begin(expected_valid_outputs),
|
||||
std::end(expected_valid_outputs),
|
||||
[&](const Place::Ptr& expected_out) {
|
||||
return expected_out->is_equal(current_out);
|
||||
}) != std::end(current_outputs);
|
||||
}),
|
||||
"Some other than expected outputs were created during override_all_outputs");
|
||||
}
|
||||
|
||||
void InputModel::override_all_inputs(const std::vector<ov::frontend::Place::Ptr>& inputs) {
|
||||
std::vector<Place::Ptr> expected_valid_inputs;
|
||||
for (const auto& input : inputs) {
|
||||
bool is_correct = is_correct_place(input);
|
||||
if (!is_correct)
|
||||
NGRAPH_WARN << "Name " << input->get_names().at(0)
|
||||
<< " of input node is not a correct node. Ignoring this parameter.";
|
||||
else
|
||||
expected_valid_inputs.push_back(input);
|
||||
}
|
||||
|
||||
const auto outputs_before_extraction = m_editor->model_outputs();
|
||||
extract_subgraph({inputs}, {});
|
||||
extract_subgraph({expected_valid_inputs}, {});
|
||||
|
||||
NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction),
|
||||
std::end(outputs_before_extraction),
|
||||
std::begin(m_editor->model_outputs())),
|
||||
"All outputs should be preserved after override_all_inputs. Provided inputs does "
|
||||
"not satisfy all outputs");
|
||||
NGRAPH_CHECK(m_editor->model_inputs().size() == inputs.size(),
|
||||
"Unexpected number of inputs after override_all_inputs");
|
||||
|
||||
const auto current_inputs = get_inputs();
|
||||
NGRAPH_CHECK(std::all_of(std::begin(current_inputs),
|
||||
std::end(current_inputs),
|
||||
[&](const Place::Ptr& current_in) {
|
||||
return std::find_if(std::begin(expected_valid_inputs),
|
||||
std::end(expected_valid_inputs),
|
||||
[&](const Place::Ptr& expected_in) {
|
||||
return expected_in->is_equal(current_in);
|
||||
}) != std::end(current_inputs);
|
||||
}),
|
||||
"Some other than expected inputs were created during override_all_inputs");
|
||||
}
|
||||
|
||||
void InputModel::extract_subgraph(const std::vector<ov::frontend::Place::Ptr>& inputs,
|
||||
|
@ -78,6 +78,7 @@ public:
|
||||
|
||||
private:
|
||||
std::shared_ptr<ov::onnx_editor::ONNXModelEditor> m_editor;
|
||||
bool is_correct_place(const ov::frontend::Place::Ptr& place) const;
|
||||
|
||||
std::unordered_map<std::string, std::unordered_set<std::string>> m_additional_tensor_names;
|
||||
void add_tensor_names(std::shared_ptr<Model>& model);
|
||||
|
Loading…
Reference in New Issue
Block a user