Add support for is_equal_data, get_source_tensor, get_target_tensor methods in ONNX FE API (#6991)

This commit is contained in:
Mateusz Bencer 2021-08-17 13:14:11 +02:00 committed by GitHub
parent 74dd61b3b3
commit 9be53225f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 196 additions and 46 deletions

View File

@ -72,25 +72,26 @@ int onnx_editor::EdgeMapper::get_node_output_idx(int node_index, const std::stri
return (out_port_idx - std::begin(node_outputs));
}
int onnx_editor::EdgeMapper::get_node_input_idx(int node_index, const std::string& input_name) const {
std::vector<int> onnx_editor::EdgeMapper::get_node_input_indexes(int node_index, const std::string& input_name) const {
NGRAPH_CHECK(node_index >= 0 && node_index < static_cast<int>(m_node_inputs.size()),
"Node with index: ",
std::to_string(node_index),
"is out of scope outputs list");
const auto& node_inputs = m_node_inputs[node_index];
const auto matched_inputs = std::count(std::begin(node_inputs), std::end(node_inputs), input_name);
if (matched_inputs == 0) {
std::vector<int> node_inputs_indexes;
int index = 0;
for (const auto& in : node_inputs) {
if (in == input_name) {
node_inputs_indexes.push_back(index);
}
++index;
}
if (node_inputs_indexes.size() == 0) {
throw ngraph_error("Node with index: " + std::to_string(node_index) +
" has not input with name: " + input_name);
}
if (matched_inputs > 1) // more indexes with the same name
{
throw ngraph_error("Node with index: " + std::to_string(node_index) + " has more than one inputs with name: " +
input_name + ". You should use port indexes to distinguish them.");
}
const auto in_port_idx = std::find(std::begin(node_inputs), std::end(node_inputs), input_name);
return (in_port_idx - std::begin(node_inputs));
return node_inputs_indexes;
}
InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const EditorInput& in) const {
@ -131,8 +132,14 @@ InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const
return InputEdge{node_index, in.m_input_index, in.m_new_input_name};
}
if (!in.m_input_name.empty()) {
const auto input_idx = get_node_input_idx(node_index, in.m_input_name);
return InputEdge{node_index, input_idx, in.m_new_input_name};
const auto input_indexes = get_node_input_indexes(node_index, in.m_input_name);
if (input_indexes.size() > 1) // more indexes with the same name
{
throw ngraph_error("Node with index: " + std::to_string(node_index) +
" has more than one inputs with name: " + in.m_input_name +
". You should use port indexes to distinguish them.");
}
return InputEdge{node_index, input_indexes[0], in.m_new_input_name};
} else {
throw ngraph_error("Not enough information to determine input edge");
}
@ -186,14 +193,13 @@ OutputEdge onnx_editor::EdgeMapper::find_output_edge(const std::string& output_n
std::vector<InputEdge> onnx_editor::EdgeMapper::find_output_consumers(const std::string& output_name) const {
const auto matched_nodes_range = m_output_consumers_index.equal_range(output_name);
std::vector<InputEdge> input_edges;
std::transform(matched_nodes_range.first,
matched_nodes_range.second,
std::back_inserter(input_edges),
[&output_name, this](const std::pair<std::string, int>& iter) {
const auto node_idx = iter.second;
const auto port_idx = this->get_node_input_idx(node_idx, output_name);
return InputEdge{node_idx, port_idx};
});
for (auto it = matched_nodes_range.first; it != matched_nodes_range.second; ++it) {
const auto node_idx = it->second;
const auto port_indexes = get_node_input_indexes(node_idx, output_name);
for (const auto& idx : port_indexes) {
input_edges.push_back(InputEdge{node_idx, idx});
}
}
return input_edges;
}
@ -211,7 +217,7 @@ bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) co
return false;
}
std::string onnx_editor::EdgeMapper::get_input_port_name(const InputEdge& edge) const {
std::string onnx_editor::EdgeMapper::get_source_tensor_name(const InputEdge& edge) const {
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_inputs.size()) && edge.m_port_idx >= 0 &&
edge.m_port_idx < static_cast<int>(m_node_inputs[edge.m_node_idx].size())) {
return m_node_inputs[edge.m_node_idx][edge.m_port_idx];
@ -219,7 +225,7 @@ std::string onnx_editor::EdgeMapper::get_input_port_name(const InputEdge& edge)
return "";
}
std::string onnx_editor::EdgeMapper::get_output_port_name(const OutputEdge& edge) const {
std::string onnx_editor::EdgeMapper::get_target_tensor_name(const OutputEdge& edge) const {
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_outputs.size()) && edge.m_port_idx >= 0 &&
edge.m_port_idx < static_cast<int>(m_node_outputs[edge.m_node_idx].size())) {
return m_node_outputs[edge.m_node_idx][edge.m_port_idx];

View File

@ -98,22 +98,23 @@ public:
///
bool is_correct_tensor_name(const std::string& name) const;
/// \brief Get name of input port indicated by the input edge.
/// \brief Get name of the tensor which is the source of the input edge.
///
/// \note Empty string is returned if the port name is not found.
/// \note Empty string is returned if the tensor name is not found.
///
std::string get_input_port_name(const InputEdge& edge) const;
std::string get_source_tensor_name(const InputEdge& edge) const;
/// \brief Get name of output port indicated by the input edge.
/// \brief Get name of the tensor which is the target of the output edge.
///
/// \note Empty string is returned if the port name is not found.
/// \note Empty string is returned if the tensor name is not found.
///
std::string get_output_port_name(const OutputEdge& edge) const;
std::string get_target_tensor_name(const OutputEdge& edge) const;
private:
std::vector<int> find_node_indexes(const std::string& node_name, const std::string& output_name) const;
int get_node_input_idx(int node_index, const std::string& input_name) const;
// note: a single node can have more than one inputs with the same name
std::vector<int> get_node_input_indexes(int node_index, const std::string& input_name) const;
int get_node_output_idx(int node_index, const std::string& output_name) const;
std::vector<std::vector<std::string>> m_node_inputs;

View File

@ -325,25 +325,33 @@ std::vector<std::string> onnx_editor::ONNXModelEditor::model_outputs() const {
return outputs;
}
bool onnx_editor::ONNXModelEditor::is_input(const InputEdge& edge) const {
std::string onnx_editor::ONNXModelEditor::get_source_tensor_name(const InputEdge& edge) const {
update_mapper_if_needed();
const auto& port_name = m_pimpl->m_edge_mapper.get_input_port_name(edge);
if (port_name.empty()) {
return m_pimpl->m_edge_mapper.get_source_tensor_name(edge);
}
bool onnx_editor::ONNXModelEditor::is_input(const InputEdge& edge) const {
const auto& tensor_name = get_source_tensor_name(edge);
if (tensor_name.empty()) {
return false;
} else {
const auto& inputs = model_inputs();
return std::count(std::begin(inputs), std::end(inputs), port_name) > 0;
return std::count(std::begin(inputs), std::end(inputs), tensor_name) > 0;
}
}
bool onnx_editor::ONNXModelEditor::is_output(const OutputEdge& edge) const {
std::string onnx_editor::ONNXModelEditor::get_target_tensor_name(const OutputEdge& edge) const {
update_mapper_if_needed();
const auto& port_name = m_pimpl->m_edge_mapper.get_output_port_name(edge);
if (port_name.empty()) {
return m_pimpl->m_edge_mapper.get_target_tensor_name(edge);
}
bool onnx_editor::ONNXModelEditor::is_output(const OutputEdge& edge) const {
const auto& tensor_name = get_target_tensor_name(edge);
if (tensor_name.empty()) {
return false;
} else {
const auto& outputs = model_outputs();
return std::count(std::begin(outputs), std::end(outputs), port_name) > 0;
return std::count(std::begin(outputs), std::end(outputs), tensor_name) > 0;
}
}

View File

@ -100,9 +100,21 @@ public:
/// instance of the model editor.
std::vector<std::string> model_outputs() const;
/// \brief Get name of the tensor which is the source of the input edge.
///
/// \note Empty string is returned if the tensor name is not found.
///
std::string get_source_tensor_name(const InputEdge& edge) const;
/// \brief Returns true if input edge is input of the model. Otherwise false.
bool is_input(const InputEdge& edge) const;
/// \brief Get name of the tensor which is the target of the output edge.
///
/// \note Empty string is returned if the tensor name is not found.
///
std::string get_target_tensor_name(const OutputEdge& edge) const;
/// \brief Returns true if output edge is output of the model. Otherwise false.
bool is_output(const OutputEdge& edge) const;

View File

@ -34,6 +34,15 @@ bool PlaceInputEdgeONNX::is_equal(Place::Ptr another) const {
return false;
}
// An input edge carries the same data as the tensor which feeds it, so the
// data-equality check is delegated to the source tensor place.
bool PlaceInputEdgeONNX::is_equal_data(Place::Ptr another) const {
return get_source_tensor()->is_equal_data(another);
}
// Returns a tensor place for the tensor which is the source of this input
// edge. The editor resolves the edge to a tensor name; the name may be an
// empty string if the edge does not point at a valid node/port.
Place::Ptr PlaceInputEdgeONNX::get_source_tensor() const {
const auto tensor_name = m_editor->get_source_tensor_name(m_edge);
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
}
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_edge{edge},
@ -59,6 +68,15 @@ bool PlaceOutputEdgeONNX::is_equal(Place::Ptr another) const {
return false;
}
// An output edge carries the same data as the tensor it produces, so the
// data-equality check is delegated to the target tensor place.
bool PlaceOutputEdgeONNX::is_equal_data(Place::Ptr another) const {
return get_target_tensor()->is_equal_data(another);
}
// Returns a tensor place for the tensor which is the target of this output
// edge. The editor resolves the edge to a tensor name; the name may be an
// empty string if the edge does not point at a valid node/port.
Place::Ptr PlaceOutputEdgeONNX::get_target_tensor() const {
const auto tensor_name = m_editor->get_target_tensor_name(m_edge);
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
}
PlaceTensorONNX::PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_name(name),
m_editor(editor) {}
@ -68,6 +86,8 @@ std::vector<std::string> PlaceTensorONNX::get_names() const {
}
Place::Ptr PlaceTensorONNX::get_producing_port() const {
FRONT_END_GENERAL_CHECK(!is_input(),
"Tensor: " + m_name + " is an input of the model and doesn't have producing port.");
return std::make_shared<PlaceOutputEdgeONNX>(m_editor->find_output_edge(m_name), m_editor);
}
@ -102,3 +122,14 @@ bool PlaceTensorONNX::is_equal(Place::Ptr another) const {
}
return false;
}
// Two places hold the same data when 'another' is this tensor itself, the
// port producing this tensor, or any port consuming this tensor.
bool PlaceTensorONNX::is_equal_data(Place::Ptr another) const {
    const auto consuming_ports = get_consuming_ports();
    // Lambda parameter renamed (was 'another') to avoid shadowing the
    // function argument, which triggered -Wshadow and obscured intent.
    const auto eq_to_consuming_port = [&consuming_ports](const Ptr& candidate) {
        return std::any_of(consuming_ports.begin(), consuming_ports.end(), [&candidate](const Ptr& place) {
            return place->is_equal(candidate);
        });
    };
    // A model input has no producing port, so that comparison is skipped.
    return is_equal(another) || (!is_input() && get_producing_port()->is_equal(another)) ||
           eq_to_consuming_port(another);
}

View File

@ -22,6 +22,10 @@ public:
bool is_equal(Place::Ptr another) const override;
bool is_equal_data(Place::Ptr another) const override;
Place::Ptr get_source_tensor() const override;
private:
onnx_editor::InputEdge m_edge;
const std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
@ -39,6 +43,10 @@ public:
bool is_equal(Place::Ptr another) const override;
bool is_equal_data(Place::Ptr another) const override;
Place::Ptr get_target_tensor() const override;
private:
onnx_editor::OutputEdge m_edge;
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
@ -62,6 +70,8 @@ public:
bool is_equal(Place::Ptr another) const override;
bool is_equal_data(Place::Ptr another) const override;
private:
std::string m_name;
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;

View File

@ -45,7 +45,8 @@ def create_test_onnx_models():
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph
input_tensors = [
@ -56,7 +57,8 @@ def create_test_onnx_models():
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add], "test_graph", input_tensors, output_tensors)
models["extract_subgraph.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["extract_subgraph.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph 2
input_tensors = [
@ -69,7 +71,8 @@ def create_test_onnx_models():
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["extract_subgraph_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph 3
input_tensors = [
@ -82,7 +85,8 @@ def create_test_onnx_models():
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
outputs=["out1", "out2"], name="split1", axis=0)
graph = make_graph([expected_split], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_3.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["extract_subgraph_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph 4
input_tensors = [
@ -100,7 +104,8 @@ def create_test_onnx_models():
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
outputs=["out4"])
graph = make_graph([expected_split, expected_mul], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for test_override_all_outputs
input_tensors = [
@ -113,7 +118,8 @@ def create_test_onnx_models():
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
models["test_override_all_outputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["test_override_all_outputs.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for test_override_all_outputs 2
input_tensors = [
@ -124,7 +130,8 @@ def create_test_onnx_models():
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, mul], "test_graph", input_tensors, output_tensors)
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for test_override_all_inputs
input_tensors = [
@ -144,7 +151,8 @@ def create_test_onnx_models():
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
outputs=["out4"])
graph = make_graph([expected_split, relu, expected_mul], "test_graph", input_tensors, output_tensors)
models["test_override_all_inputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["test_override_all_inputs.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# test partial shape
input_tensors = [
@ -159,7 +167,8 @@ def create_test_onnx_models():
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (8, 16)),
]
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer")
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
return models
@ -530,6 +539,45 @@ def test_is_equal():
assert not place8.is_equal(place2)
def test_is_equal_data():
    # Verifies Place.is_equal_data across tensor places, producing ports and
    # consuming ports of the ONNX frontend.
    # NOTE(review): indentation was restored here — the extracted diff had
    # lost it; the statement order and assertions are unchanged.
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert fe

    model = fe.load("input_model.onnx")
    assert model

    # A place is always equal-data to itself.
    place1 = model.get_place_by_tensor_name(tensorName="in1")
    assert place1.is_equal_data(place1)

    place2 = model.get_place_by_tensor_name(tensorName="add_out")
    assert place2.is_equal_data(place2)

    # Distinct tensors do not share data.
    place3 = model.get_place_by_tensor_name(tensorName="in2")
    assert not place1.is_equal_data(place3)
    assert not place2.is_equal_data(place1)

    # A tensor shares data with its producing port...
    place4 = place2.get_producing_port()
    assert place2.is_equal_data(place4)

    # ...and with the ports that consume it.
    place5 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
    assert place2.is_equal_data(place5)
    assert place4.is_equal_data(place5)

    place6 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
    assert place6.is_equal_data(place5)

    place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
    assert place7.is_equal_data(place7)

    # Different outputs of the same node hold different data, but each output
    # tensor shares data with its own producing port.
    place8 = model.get_place_by_tensor_name(tensorName="out1")
    place9 = model.get_place_by_tensor_name(tensorName="out2")
    place10 = place8.get_producing_port()
    assert not place8.is_equal_data(place9)
    assert not place9.is_equal_data(place10)
    assert place8.is_equal_data(place10)
def test_get_place_by_tensor_name():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)

View File

@ -1019,6 +1019,16 @@ NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers_empty_result) {
EXPECT_EQ(output_consumers.size(), 0);
}
// find_output_consumers must return one InputEdge per matching input port,
// including the case where several inputs of a single node reuse the same
// tensor name ("X" feeds both ports 0 and 1 of node 1 in add_ab.onnx).
NGRAPH_TEST(onnx_editor, editor_api_inputs_with_the_same_name) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_ab.onnx")};
std::vector<InputEdge> output_consumers = editor.find_output_consumers("X");
EXPECT_EQ(output_consumers[0].m_node_idx, 1);
EXPECT_EQ(output_consumers[0].m_port_idx, 0);
EXPECT_EQ(output_consumers[1].m_node_idx, 1);
EXPECT_EQ(output_consumers[1].m_port_idx, 1);
}
NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
@ -1204,6 +1214,19 @@ NGRAPH_TEST(onnx_editor, cut_operator_with_no_schema) {
EXPECT_TRUE(result.is_ok) << result.error_message;
}
// get_source_tensor_name resolves an InputEdge (node index + port index, or
// an edge located via find_input_edge) to the name of the tensor feeding it.
NGRAPH_TEST(onnx_editor, get_source_tensor_name) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{0, 0}), "in1");
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{1, 0}), "relu1");
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{1, 1}), "in2");
const auto edge1 = editor.find_input_edge(EditorOutput{"conv1"}, 1);
EXPECT_EQ(editor.get_source_tensor_name(edge1), "in4");
const auto edge2 = editor.find_input_edge(EditorOutput{"split2"}, 0);
EXPECT_EQ(editor.get_source_tensor_name(edge2), "add2");
// Out-of-range node/port indexes yield an empty name instead of throwing.
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{999, 999}), "");
}
NGRAPH_TEST(onnx_editor, is_model_input) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
@ -1221,6 +1244,17 @@ NGRAPH_TEST(onnx_editor, is_model_input) {
EXPECT_FALSE(editor.is_input(edge3));
}
// get_target_tensor_name resolves an OutputEdge (node index + port index, or
// an edge located via find_output_edge) to the name of the tensor it produces.
NGRAPH_TEST(onnx_editor, get_target_tensor_name) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{0, 0}), "relu1");
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{1, 0}), "add1");
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{4, 0}), "mul2");
const auto edge1 = editor.find_output_edge("split1");
EXPECT_EQ(editor.get_target_tensor_name(edge1), "split1");
// Out-of-range node/port indexes yield an empty name instead of throwing.
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{999, 999}), "");
}
NGRAPH_TEST(onnx_editor, is_model_output) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};