Add support for is_equal_data, get_source_tensor, get_target_tensor methods in ONNX FE API (#6991)
This commit is contained in:
parent
74dd61b3b3
commit
9be53225f0
@ -72,25 +72,26 @@ int onnx_editor::EdgeMapper::get_node_output_idx(int node_index, const std::stri
|
||||
return (out_port_idx - std::begin(node_outputs));
|
||||
}
|
||||
|
||||
int onnx_editor::EdgeMapper::get_node_input_idx(int node_index, const std::string& input_name) const {
|
||||
std::vector<int> onnx_editor::EdgeMapper::get_node_input_indexes(int node_index, const std::string& input_name) const {
|
||||
NGRAPH_CHECK(node_index >= 0 && node_index < static_cast<int>(m_node_inputs.size()),
|
||||
"Node with index: ",
|
||||
std::to_string(node_index),
|
||||
"is out of scope outputs list");
|
||||
|
||||
const auto& node_inputs = m_node_inputs[node_index];
|
||||
const auto matched_inputs = std::count(std::begin(node_inputs), std::end(node_inputs), input_name);
|
||||
if (matched_inputs == 0) {
|
||||
std::vector<int> node_inputs_indexes;
|
||||
int index = 0;
|
||||
for (const auto& in : node_inputs) {
|
||||
if (in == input_name) {
|
||||
node_inputs_indexes.push_back(index);
|
||||
}
|
||||
++index;
|
||||
}
|
||||
if (node_inputs_indexes.size() == 0) {
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) +
|
||||
" has not input with name: " + input_name);
|
||||
}
|
||||
if (matched_inputs > 1) // more indexes with the same name
|
||||
{
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) + " has more than one inputs with name: " +
|
||||
input_name + ". You should use port indexes to distinguish them.");
|
||||
}
|
||||
const auto in_port_idx = std::find(std::begin(node_inputs), std::end(node_inputs), input_name);
|
||||
return (in_port_idx - std::begin(node_inputs));
|
||||
return node_inputs_indexes;
|
||||
}
|
||||
|
||||
InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const EditorInput& in) const {
|
||||
@ -131,8 +132,14 @@ InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const
|
||||
return InputEdge{node_index, in.m_input_index, in.m_new_input_name};
|
||||
}
|
||||
if (!in.m_input_name.empty()) {
|
||||
const auto input_idx = get_node_input_idx(node_index, in.m_input_name);
|
||||
return InputEdge{node_index, input_idx, in.m_new_input_name};
|
||||
const auto input_indexes = get_node_input_indexes(node_index, in.m_input_name);
|
||||
if (input_indexes.size() > 1) // more indexes with the same name
|
||||
{
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) +
|
||||
" has more than one inputs with name: " + in.m_input_name +
|
||||
". You should use port indexes to distinguish them.");
|
||||
}
|
||||
return InputEdge{node_index, input_indexes[0], in.m_new_input_name};
|
||||
} else {
|
||||
throw ngraph_error("Not enough information to determine input edge");
|
||||
}
|
||||
@ -186,14 +193,13 @@ OutputEdge onnx_editor::EdgeMapper::find_output_edge(const std::string& output_n
|
||||
std::vector<InputEdge> onnx_editor::EdgeMapper::find_output_consumers(const std::string& output_name) const {
|
||||
const auto matched_nodes_range = m_output_consumers_index.equal_range(output_name);
|
||||
std::vector<InputEdge> input_edges;
|
||||
std::transform(matched_nodes_range.first,
|
||||
matched_nodes_range.second,
|
||||
std::back_inserter(input_edges),
|
||||
[&output_name, this](const std::pair<std::string, int>& iter) {
|
||||
const auto node_idx = iter.second;
|
||||
const auto port_idx = this->get_node_input_idx(node_idx, output_name);
|
||||
return InputEdge{node_idx, port_idx};
|
||||
});
|
||||
for (auto it = matched_nodes_range.first; it != matched_nodes_range.second; ++it) {
|
||||
const auto node_idx = it->second;
|
||||
const auto port_indexes = get_node_input_indexes(node_idx, output_name);
|
||||
for (const auto& idx : port_indexes) {
|
||||
input_edges.push_back(InputEdge{node_idx, idx});
|
||||
}
|
||||
}
|
||||
return input_edges;
|
||||
}
|
||||
|
||||
@ -211,7 +217,7 @@ bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) co
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string onnx_editor::EdgeMapper::get_input_port_name(const InputEdge& edge) const {
|
||||
std::string onnx_editor::EdgeMapper::get_source_tensor_name(const InputEdge& edge) const {
|
||||
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_inputs.size()) && edge.m_port_idx >= 0 &&
|
||||
edge.m_port_idx < static_cast<int>(m_node_inputs[edge.m_node_idx].size())) {
|
||||
return m_node_inputs[edge.m_node_idx][edge.m_port_idx];
|
||||
@ -219,7 +225,7 @@ std::string onnx_editor::EdgeMapper::get_input_port_name(const InputEdge& edge)
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string onnx_editor::EdgeMapper::get_output_port_name(const OutputEdge& edge) const {
|
||||
std::string onnx_editor::EdgeMapper::get_target_tensor_name(const OutputEdge& edge) const {
|
||||
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_outputs.size()) && edge.m_port_idx >= 0 &&
|
||||
edge.m_port_idx < static_cast<int>(m_node_outputs[edge.m_node_idx].size())) {
|
||||
return m_node_outputs[edge.m_node_idx][edge.m_port_idx];
|
||||
|
@ -98,22 +98,23 @@ public:
|
||||
///
|
||||
bool is_correct_tensor_name(const std::string& name) const;
|
||||
|
||||
/// \brief Get name of input port indicated by the input edge.
|
||||
/// \brief Get name of the tensor which is the source of the input edge.
|
||||
///
|
||||
/// \note Empty string is returned if the port name is not found.
|
||||
/// \note Empty string is returned if the tensor name is not found.
|
||||
///
|
||||
std::string get_input_port_name(const InputEdge& edge) const;
|
||||
std::string get_source_tensor_name(const InputEdge& edge) const;
|
||||
|
||||
/// \brief Get name of output port indicated by the input edge.
|
||||
/// \brief Get name of the tensor which is the target of the output edge.
|
||||
///
|
||||
/// \note Empty string is returned if the port name is not found.
|
||||
/// \note Empty string is returned if the tensor name is not found.
|
||||
///
|
||||
std::string get_output_port_name(const OutputEdge& edge) const;
|
||||
std::string get_target_tensor_name(const OutputEdge& edge) const;
|
||||
|
||||
private:
|
||||
std::vector<int> find_node_indexes(const std::string& node_name, const std::string& output_name) const;
|
||||
|
||||
int get_node_input_idx(int node_index, const std::string& input_name) const;
|
||||
// note: a single node can have more than one inputs with the same name
|
||||
std::vector<int> get_node_input_indexes(int node_index, const std::string& input_name) const;
|
||||
int get_node_output_idx(int node_index, const std::string& output_name) const;
|
||||
|
||||
std::vector<std::vector<std::string>> m_node_inputs;
|
||||
|
@ -325,25 +325,33 @@ std::vector<std::string> onnx_editor::ONNXModelEditor::model_outputs() const {
|
||||
return outputs;
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_input(const InputEdge& edge) const {
|
||||
std::string onnx_editor::ONNXModelEditor::get_source_tensor_name(const InputEdge& edge) const {
|
||||
update_mapper_if_needed();
|
||||
const auto& port_name = m_pimpl->m_edge_mapper.get_input_port_name(edge);
|
||||
if (port_name.empty()) {
|
||||
return m_pimpl->m_edge_mapper.get_source_tensor_name(edge);
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_input(const InputEdge& edge) const {
|
||||
const auto& tensor_name = get_source_tensor_name(edge);
|
||||
if (tensor_name.empty()) {
|
||||
return false;
|
||||
} else {
|
||||
const auto& inputs = model_inputs();
|
||||
return std::count(std::begin(inputs), std::end(inputs), port_name) > 0;
|
||||
return std::count(std::begin(inputs), std::end(inputs), tensor_name) > 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_output(const OutputEdge& edge) const {
|
||||
std::string onnx_editor::ONNXModelEditor::get_target_tensor_name(const OutputEdge& edge) const {
|
||||
update_mapper_if_needed();
|
||||
const auto& port_name = m_pimpl->m_edge_mapper.get_output_port_name(edge);
|
||||
if (port_name.empty()) {
|
||||
return m_pimpl->m_edge_mapper.get_target_tensor_name(edge);
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_output(const OutputEdge& edge) const {
|
||||
const auto& tensor_name = get_target_tensor_name(edge);
|
||||
if (tensor_name.empty()) {
|
||||
return false;
|
||||
} else {
|
||||
const auto& outputs = model_outputs();
|
||||
return std::count(std::begin(outputs), std::end(outputs), port_name) > 0;
|
||||
return std::count(std::begin(outputs), std::end(outputs), tensor_name) > 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,9 +100,21 @@ public:
|
||||
/// instance of the model editor.
|
||||
std::vector<std::string> model_outputs() const;
|
||||
|
||||
/// \brief Get name of the tensor which is the source of the input edge.
|
||||
///
|
||||
/// \note Empty string is returned if the tensor name is not found.
|
||||
///
|
||||
std::string get_source_tensor_name(const InputEdge& edge) const;
|
||||
|
||||
/// \brief Returns true if input edge is input of the model. Otherwise false.
|
||||
bool is_input(const InputEdge& edge) const;
|
||||
|
||||
/// \brief Get name of the tensor which is the target of the output edge.
|
||||
///
|
||||
/// \note Empty string is returned if the tensor name is not found.
|
||||
///
|
||||
std::string get_target_tensor_name(const OutputEdge& edge) const;
|
||||
|
||||
/// \brief Returns true if output edge is input of the model. Otherwise false.
|
||||
bool is_output(const OutputEdge& edge) const;
|
||||
|
||||
|
@ -34,6 +34,15 @@ bool PlaceInputEdgeONNX::is_equal(Place::Ptr another) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlaceInputEdgeONNX::is_equal_data(Place::Ptr another) const {
|
||||
return get_source_tensor()->is_equal_data(another);
|
||||
}
|
||||
|
||||
Place::Ptr PlaceInputEdgeONNX::get_source_tensor() const {
|
||||
const auto tensor_name = m_editor->get_source_tensor_name(m_edge);
|
||||
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
|
||||
}
|
||||
|
||||
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_edge{edge},
|
||||
@ -59,6 +68,15 @@ bool PlaceOutputEdgeONNX::is_equal(Place::Ptr another) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlaceOutputEdgeONNX::is_equal_data(Place::Ptr another) const {
|
||||
return get_target_tensor()->is_equal_data(another);
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOutputEdgeONNX::get_target_tensor() const {
|
||||
const auto tensor_name = m_editor->get_target_tensor_name(m_edge);
|
||||
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
|
||||
}
|
||||
|
||||
PlaceTensorONNX::PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_name(name),
|
||||
m_editor(editor) {}
|
||||
@ -68,6 +86,8 @@ std::vector<std::string> PlaceTensorONNX::get_names() const {
|
||||
}
|
||||
|
||||
Place::Ptr PlaceTensorONNX::get_producing_port() const {
|
||||
FRONT_END_GENERAL_CHECK(!is_input(),
|
||||
"Tensor: " + m_name + " is an input of the model and doesn't have producing port.");
|
||||
return std::make_shared<PlaceOutputEdgeONNX>(m_editor->find_output_edge(m_name), m_editor);
|
||||
}
|
||||
|
||||
@ -102,3 +122,14 @@ bool PlaceTensorONNX::is_equal(Place::Ptr another) const {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlaceTensorONNX::is_equal_data(Place::Ptr another) const {
|
||||
const auto consuming_ports = get_consuming_ports();
|
||||
const auto eq_to_consuming_port = [&consuming_ports](const Ptr& another) {
|
||||
return std::any_of(consuming_ports.begin(), consuming_ports.end(), [&another](const Ptr& place) {
|
||||
return place->is_equal(another);
|
||||
});
|
||||
};
|
||||
return is_equal(another) || (is_input() ? false : get_producing_port()->is_equal(another)) ||
|
||||
eq_to_consuming_port(another);
|
||||
}
|
||||
|
@ -22,6 +22,10 @@ public:
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
bool is_equal_data(Place::Ptr another) const override;
|
||||
|
||||
Place::Ptr get_source_tensor() const override;
|
||||
|
||||
private:
|
||||
onnx_editor::InputEdge m_edge;
|
||||
const std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
@ -39,6 +43,10 @@ public:
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
bool is_equal_data(Place::Ptr another) const override;
|
||||
|
||||
Place::Ptr get_target_tensor() const override;
|
||||
|
||||
private:
|
||||
onnx_editor::OutputEdge m_edge;
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
@ -62,6 +70,8 @@ public:
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
bool is_equal_data(Place::Ptr another) const override;
|
||||
|
||||
private:
|
||||
std::string m_name;
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
|
@ -45,7 +45,8 @@ def create_test_onnx_models():
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
|
||||
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph
|
||||
input_tensors = [
|
||||
@ -56,7 +57,8 @@ def create_test_onnx_models():
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["extract_subgraph.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph 2
|
||||
input_tensors = [
|
||||
@ -69,7 +71,8 @@ def create_test_onnx_models():
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["extract_subgraph_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph 3
|
||||
input_tensors = [
|
||||
@ -82,7 +85,8 @@ def create_test_onnx_models():
|
||||
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
|
||||
outputs=["out1", "out2"], name="split1", axis=0)
|
||||
graph = make_graph([expected_split], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_3.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["extract_subgraph_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph 4
|
||||
input_tensors = [
|
||||
@ -100,7 +104,8 @@ def create_test_onnx_models():
|
||||
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
|
||||
outputs=["out4"])
|
||||
graph = make_graph([expected_split, expected_mul], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for test_override_all_outputs
|
||||
input_tensors = [
|
||||
@ -113,7 +118,8 @@ def create_test_onnx_models():
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
|
||||
models["test_override_all_outputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["test_override_all_outputs.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for test_override_all_outputs 2
|
||||
input_tensors = [
|
||||
@ -124,7 +130,8 @@ def create_test_onnx_models():
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, mul], "test_graph", input_tensors, output_tensors)
|
||||
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for test_override_all_inputs
|
||||
input_tensors = [
|
||||
@ -144,7 +151,8 @@ def create_test_onnx_models():
|
||||
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
|
||||
outputs=["out4"])
|
||||
graph = make_graph([expected_split, relu, expected_mul], "test_graph", input_tensors, output_tensors)
|
||||
models["test_override_all_inputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["test_override_all_inputs.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# test partial shape
|
||||
input_tensors = [
|
||||
@ -159,7 +167,8 @@ def create_test_onnx_models():
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (8, 16)),
|
||||
]
|
||||
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
|
||||
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
return models
|
||||
|
||||
@ -530,6 +539,45 @@ def test_is_equal():
|
||||
assert not place8.is_equal(place2)
|
||||
|
||||
|
||||
def test_is_equal_data():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="in1")
|
||||
assert place1.is_equal_data(place1)
|
||||
|
||||
place2 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
assert place2.is_equal_data(place2)
|
||||
|
||||
place3 = model.get_place_by_tensor_name(tensorName="in2")
|
||||
assert not place1.is_equal_data(place3)
|
||||
assert not place2.is_equal_data(place1)
|
||||
|
||||
place4 = place2.get_producing_port()
|
||||
assert place2.is_equal_data(place4)
|
||||
|
||||
place5 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
assert place2.is_equal_data(place5)
|
||||
assert place4.is_equal_data(place5)
|
||||
|
||||
place6 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
|
||||
assert place6.is_equal_data(place5)
|
||||
|
||||
place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
assert place7.is_equal_data(place7)
|
||||
|
||||
place8 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
place9 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
place10 = place8.get_producing_port()
|
||||
assert not place8.is_equal_data(place9)
|
||||
assert not place9.is_equal_data(place10)
|
||||
assert place8.is_equal_data(place10)
|
||||
|
||||
|
||||
def test_get_place_by_tensor_name():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
|
@ -1019,6 +1019,16 @@ NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers_empty_result) {
|
||||
EXPECT_EQ(output_consumers.size(), 0);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_inputs_with_the_same_name) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_ab.onnx")};
|
||||
|
||||
std::vector<InputEdge> output_consumers = editor.find_output_consumers("X");
|
||||
EXPECT_EQ(output_consumers[0].m_node_idx, 1);
|
||||
EXPECT_EQ(output_consumers[0].m_port_idx, 0);
|
||||
EXPECT_EQ(output_consumers[1].m_node_idx, 1);
|
||||
EXPECT_EQ(output_consumers[1].m_port_idx, 1);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
@ -1204,6 +1214,19 @@ NGRAPH_TEST(onnx_editor, cut_operator_with_no_schema) {
|
||||
EXPECT_TRUE(result.is_ok) << result.error_message;
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, get_source_tensor_name) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{0, 0}), "in1");
|
||||
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{1, 0}), "relu1");
|
||||
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{1, 1}), "in2");
|
||||
const auto edge1 = editor.find_input_edge(EditorOutput{"conv1"}, 1);
|
||||
EXPECT_EQ(editor.get_source_tensor_name(edge1), "in4");
|
||||
const auto edge2 = editor.find_input_edge(EditorOutput{"split2"}, 0);
|
||||
EXPECT_EQ(editor.get_source_tensor_name(edge2), "add2");
|
||||
EXPECT_EQ(editor.get_source_tensor_name(InputEdge{999, 999}), "");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, is_model_input) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
@ -1221,6 +1244,17 @@ NGRAPH_TEST(onnx_editor, is_model_input) {
|
||||
EXPECT_FALSE(editor.is_input(edge3));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, get_target_tensor_name) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{0, 0}), "relu1");
|
||||
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{1, 0}), "add1");
|
||||
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{4, 0}), "mul2");
|
||||
const auto edge1 = editor.find_output_edge("split1");
|
||||
EXPECT_EQ(editor.get_target_tensor_name(edge1), "split1");
|
||||
EXPECT_EQ(editor.get_target_tensor_name(OutputEdge{999, 999}), "");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, is_model_output) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user