Implemented missing methods of ONNX Place API (#7518)

* Implemented get_consuming_ports, get_producing_operation and is_equal for PlaceOpONNX

* fixed unambiguous_node_check

* removed PlaceTensorONNX::get_input_port

* added PlaceOpONNX::is_input, PlaceOpONNX::is_output

* fixed python styles

* added get_consuming_operations implementation

* added missing get_consuming_operations for PlaceOpONNX

* added missing get_target_tensor for PlaceOpONNX

* changed place spec

* add support of get_source_tensor

* add support of get_producing_operation for PlaceOpONNX

* add support of get_producing_port for PlaceInputEdgeONNX

* fixed python styles

* missing ref in std::transform
This commit is contained in:
Mateusz Bencer 2021-09-23 08:47:33 +02:00 committed by GitHub
parent d7dfce2091
commit fb11560b82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 864 additions and 133 deletions

View File

@ -103,7 +103,7 @@ public:
///
/// \return A vector with all operation node references that consumes data from this
/// place
virtual std::vector<Ptr> get_consuming_operations(const std::string& outputPortName) const;
virtual std::vector<Ptr> get_consuming_operations(const std::string& outputName) const;
/// \brief Returns references to all operation nodes that consume data from this place
/// for specified output port
@ -127,16 +127,14 @@ public:
/// \return A tensor place which hold the resulting value for this place
virtual Ptr get_target_tensor() const;
/// \brief Returns a tensor place that gets data from this place; applicable for
/// operations, output ports and output edges which have only one output port
/// \brief Returns a tensor place that gets data from this place; applicable for operations
///
/// \param outputName Name of output port group
///
/// \return A tensor place which hold the resulting value for this place
virtual Ptr get_target_tensor(const std::string& outputName) const;
/// \brief Returns a tensor place that gets data from this place; applicable for
/// operations, output ports and output edges which have only one output port
/// \brief Returns a tensor place that gets data from this place; applicable for operations
///
/// \param outputName Name of output port group, each group can have multiple ports
///
@ -146,8 +144,7 @@ public:
/// \return A tensor place which hold the resulting value for this place
virtual Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const;
/// \brief Returns a tensor place that gets data from this place; applicable for
/// operations, output ports and output edges
/// \brief Returns a tensor place that gets data from this place; applicable for operations
///
/// \param output_port_index Output port index if the current place is an operation node
/// and has multiple output ports
@ -161,24 +158,21 @@ public:
/// \return A tensor place which supplies data for this place
virtual Ptr get_source_tensor() const;
/// \brief Returns a tensor place that supplies data for this place; applicable for
/// operations, input ports and input edges
/// \brief Returns a tensor place that supplies data for this place; applicable for operations
///
/// \param input_port_index Input port index for operational nodes.
///
/// \return A tensor place which supplies data for this place
virtual Ptr get_source_tensor(int input_port_index) const;
/// \brief Returns a tensor place that supplies data for this place; applicable for
/// operations, input ports and input edges
/// \brief Returns a tensor place that supplies data for this place; applicable for operations
///
/// \param inputName Name of input port group
///
/// \return A tensor place which supplies data for this place
virtual Ptr get_source_tensor(const std::string& inputName) const;
/// \brief Returns a tensor place that supplies data for this place; applicable for
/// operations, input ports and input edges
/// \brief Returns a tensor place that supplies data for this place; applicable for operations
///
/// \param inputName If a given place is itself an operation node, this specifies name
/// of output port group, each group can have multiple ports

View File

@ -95,37 +95,41 @@ std::vector<int> onnx_editor::EdgeMapper::get_node_input_indexes(int node_index,
}
InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const EditorInput& in) const {
// identification can be both based on node name and output name
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
int node_index = -1;
if (node_indexes.size() == 1) {
node_index = node_indexes[0];
} else if (node_indexes.empty()) {
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" was not found");
} else if (!in.m_input_name.empty()) // input indexes are not deterministic if a node name is ambiguous
{
// many nodes with the same name
// check if some of found index matches input name
int matched_inputs_number = 0;
for (const auto& index : node_indexes) {
if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) {
node_index = index;
++matched_inputs_number;
}
}
if (matched_inputs_number == 0) {
throw ngraph_error("Input edge described by: " + node.m_node_name + " and input name: " + in.m_input_name +
int node_index = node.m_node_index;
if (node_index == -1) { // the node index is not provided
// identification can be both based on node name and output name (if the node index is not provided)
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
if (node_indexes.size() == 1) {
node_index = node_indexes[0];
} else if (node_indexes.empty()) {
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" was not found");
} else if (!in.m_input_name.empty()) // input indexes are not deterministic if a node name is ambiguous
{
// many nodes with the same name
// check if some of found index matches input name
int matched_inputs_number = 0;
for (const auto& index : node_indexes) {
if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) {
node_index = index;
++matched_inputs_number;
}
}
if (matched_inputs_number == 0) {
throw ngraph_error("Input edge described by: " + node.m_node_name +
" and input name: " + in.m_input_name + " was not found");
}
if (matched_inputs_number > 1) {
throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name +
" are ambiguous to determine input edge");
}
} else {
throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " +
std::to_string(in.m_input_index) + " are ambiguous to determine input edge");
}
if (matched_inputs_number > 1) {
throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name +
" are ambiguous to determine input edge");
}
} else {
throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " +
std::to_string(in.m_input_index) + " are ambiguous to determine input edge");
} else { // the node index is provided
check_node_index(node_index);
}
if (in.m_input_index != -1) // input index is set
{
@ -146,33 +150,38 @@ InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const
}
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node, const EditorOutput& out) const {
// identification can be both based on node name and output name
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
int node_index = -1;
if (node_indexes.size() == 1) {
node_index = node_indexes[0];
} else if (node_indexes.empty()) {
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" was not found");
} else if (!out.m_output_name.empty()) // output indexes are not deterministic if a node name is ambiguous
{
// many nodes with the same name
// check if some of found index matches output name
int matched_outputs_number = 0;
for (const auto& index : node_indexes) {
if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) > 0) {
node_index = index;
++matched_outputs_number;
int node_index = node_index = node.m_node_index;
if (node_index == -1) { // the node index is not provided
// identification can be both based on node name and output name (if the node index is not provided)
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
if (node_indexes.size() == 1) {
node_index = node_indexes[0];
} else if (node_indexes.empty()) {
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" was not found");
} else if (!out.m_output_name.empty()) // output indexes are not deterministic if a node name is ambiguous
{
// many nodes with the same name
// check if some of found index matches output name
int matched_outputs_number = 0;
for (const auto& index : node_indexes) {
if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) >
0) {
node_index = index;
++matched_outputs_number;
}
}
if (matched_outputs_number == 0) {
throw ngraph_error("Output edge described by: " + node.m_node_name +
" and output name: " + out.m_output_name + " was not found");
}
} else {
throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " +
std::to_string(out.m_output_index) + " are ambiguous to determine output edge");
}
if (matched_outputs_number == 0) {
throw ngraph_error("Output edge described by: " + node.m_node_name +
" and output name: " + out.m_output_name + " was not found");
}
} else {
throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " +
std::to_string(out.m_output_index) + " are ambiguous to determine output edge");
} else { // the node index is provided
check_node_index(node_index);
}
if (out.m_output_index != -1) // output index is set
{
@ -197,16 +206,45 @@ std::vector<InputEdge> onnx_editor::EdgeMapper::find_output_consumers(const std:
const auto node_idx = it->second;
const auto port_indexes = get_node_input_indexes(node_idx, output_name);
for (const auto& idx : port_indexes) {
input_edges.push_back(InputEdge{node_idx, idx});
const auto consumer_edge = InputEdge{node_idx, idx};
if (std::find_if(std::begin(input_edges), std::end(input_edges), [&consumer_edge](const InputEdge& edge) {
return edge.m_node_idx == consumer_edge.m_node_idx && edge.m_port_idx == consumer_edge.m_port_idx;
}) == std::end(input_edges)) {
// only unique
input_edges.push_back(consumer_edge);
}
}
}
return input_edges;
}
bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const {
if (node.m_node_index >= 0 && node.m_node_index < static_cast<int>(m_node_inputs.size())) {
return true;
}
return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
}
namespace {
void check_node(bool condition, const EditorNode& node) {
NGRAPH_CHECK(condition,
"The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
", node_index: " + (node.m_node_index == -1 ? "not_given" : std::to_string(node.m_node_index)) +
" is ambiguous");
}
} // namespace
int onnx_editor::EdgeMapper::get_node_index(const EditorNode& node) const {
if (node.m_node_index != -1) { // the node index provided
check_node_index(node.m_node_index);
return node.m_node_index;
}
const auto indexes = find_node_indexes(node.m_node_name, node.m_output_name);
check_node(indexes.size() == 1, node);
return indexes[0];
}
bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) const {
if (m_node_output_name_to_index.find(name) != std::end(m_node_output_name_to_index)) {
return true;
@ -218,20 +256,25 @@ bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) co
}
std::vector<std::string> onnx_editor::EdgeMapper::get_input_ports(const EditorNode& node) const {
NGRAPH_CHECK(is_correct_and_unambiguous_node(node),
"The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" is ambiguous");
const auto node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
check_node(is_correct_and_unambiguous_node(node), node);
auto node_index = node.m_node_index;
if (node_index == -1) { // the node index is provided
node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
} else {
check_node_index(node_index);
}
return m_node_inputs[node_index];
}
std::vector<std::string> onnx_editor::EdgeMapper::get_output_ports(const EditorNode& node) const {
NGRAPH_CHECK(is_correct_and_unambiguous_node(node),
"The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" is ambiguous");
const auto node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
check_node(is_correct_and_unambiguous_node(node), node);
auto node_index = node.m_node_index;
if (node_index == -1) // the node index is provided
{
node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
} else {
check_node_index(node_index);
}
return m_node_outputs[node_index];
}
@ -250,3 +293,8 @@ std::string onnx_editor::EdgeMapper::get_target_tensor_name(const OutputEdge& ed
}
return "";
}
void onnx_editor::EdgeMapper::check_node_index(int node_index) const {
NGRAPH_CHECK(node_index >= 0 && node_index < static_cast<int>(m_node_inputs.size()),
"Provided node index: " + std::to_string(node_index) + " is out of scope");
}

View File

@ -92,6 +92,16 @@ public:
///
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
/// \brief Returns index (position) of provided node in the graph
/// in topological order.
///
/// \param node An EditorNode helper structure created based on a node name
/// or a node output name.
///
/// \note The exception will be thrown if the provided node is ambiguous.
///
int get_node_index(const EditorNode& node) const;
/// \brief Returns true if a provided tensor name is correct (exists in a graph).
///
/// \param name The name of tensor in a graph.
@ -130,6 +140,7 @@ private:
// note: a single node can have more than one inputs with the same name
std::vector<int> get_node_input_indexes(int node_index, const std::string& input_name) const;
int get_node_output_idx(int node_index, const std::string& output_name) const;
void check_node_index(int node_index) const;
std::vector<std::vector<std::string>> m_node_inputs;
std::vector<std::vector<std::string>> m_node_outputs;

View File

@ -444,6 +444,11 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN
return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
}
int onnx_editor::ONNXModelEditor::get_node_index(const EditorNode& node) const {
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.get_node_index(node);
}
bool onnx_editor::ONNXModelEditor::is_correct_tensor_name(const std::string& name) const {
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.is_correct_tensor_name(name);

View File

@ -198,6 +198,16 @@ public:
///
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
/// \brief Returns index (position) of provided node in the graph
/// in topological order.
///
/// \param node An EditorNode helper structure created based on a node name
/// or a node output name.
///
/// \note The exception will be thrown if the provided node is ambiguous.
///
int get_node_index(const EditorNode& node) const;
/// \brief Returns true if a provided tensor name is correct (exists in a graph).
///
/// \param name The name of tensor in a graph.

View File

@ -114,11 +114,14 @@ struct EditorOutput {
/// You can indicate test_node by name as EditorNode("test_node")
/// or by assigned output as EditorNode(EditorOutput("out1"))
/// or EditorNode(EditorOutput("out2"))
/// or you can determine the node by postition of a node in an ONNX graph (in topological order).
struct EditorNode {
EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {}
EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {}
EditorNode(const int node_index) : m_node_index{node_index} {}
const std::string m_node_name = "";
const std::string m_output_name = "";
const int m_node_index = -1;
};
} // namespace onnx_editor
} // namespace ngraph

View File

@ -47,6 +47,18 @@ Place::Ptr PlaceInputEdgeONNX::get_source_tensor() const {
return std::make_shared<PlaceTensorONNX>(m_editor->get_source_tensor_name(m_edge), m_editor);
}
std::vector<Place::Ptr> PlaceInputEdgeONNX::get_consuming_operations() const {
return {std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor)};
}
Place::Ptr PlaceInputEdgeONNX::get_producing_operation() const {
return get_source_tensor()->get_producing_operation();
}
Place::Ptr PlaceInputEdgeONNX::get_producing_port() const {
return get_source_tensor()->get_producing_port();
}
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_edge{edge},
@ -85,6 +97,18 @@ Place::Ptr PlaceOutputEdgeONNX::get_target_tensor() const {
return std::make_shared<PlaceTensorONNX>(m_editor->get_target_tensor_name(m_edge), m_editor);
}
std::vector<Place::Ptr> PlaceOutputEdgeONNX::get_consuming_ports() const {
return get_target_tensor()->get_consuming_ports();
}
Place::Ptr PlaceOutputEdgeONNX::get_producing_operation() const {
return std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor);
}
std::vector<Place::Ptr> PlaceOutputEdgeONNX::get_consuming_operations() const {
return get_target_tensor()->get_consuming_operations();
}
PlaceTensorONNX::PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_name{name},
m_editor{std::move(editor)} {}
@ -112,10 +136,8 @@ std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_ports() const {
return ret;
}
Place::Ptr PlaceTensorONNX::get_input_port(int input_port_index) const {
return std::make_shared<PlaceInputEdgeONNX>(
m_editor->find_input_edge(onnx_editor::EditorOutput(m_name), onnx_editor::EditorInput(input_port_index)),
m_editor);
Place::Ptr PlaceTensorONNX::get_producing_operation() const {
return get_producing_port()->get_producing_operation();
}
bool PlaceTensorONNX::is_input() const {
@ -146,6 +168,19 @@ bool PlaceTensorONNX::is_equal_data(Place::Ptr another) const {
eq_to_consuming_port(another);
}
std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_operations() const {
std::vector<Place::Ptr> consuming_ports = get_consuming_ports();
std::vector<Place::Ptr> consuming_ops;
std::transform(std::begin(consuming_ports),
std::end(consuming_ports),
std::back_inserter(consuming_ops),
[](const Place::Ptr& place) {
return place->get_consuming_operations().at(0);
});
return consuming_ops;
}
PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_node{node},
m_editor{std::move(editor)} {}
@ -158,6 +193,10 @@ std::vector<std::string> PlaceOpONNX::get_names() const {
return {m_node.m_node_name};
}
onnx_editor::EditorNode PlaceOpONNX::get_editor_node() const {
return m_node;
}
Place::Ptr PlaceOpONNX::get_output_port() const {
if (m_editor->get_output_ports(m_node).size() == 1) {
return get_output_port(0);
@ -209,3 +248,133 @@ Place::Ptr PlaceOpONNX::get_input_port(const std::string& input_name) const {
}
return nullptr;
}
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_ports() const {
std::vector<Place::Ptr> consuming_ports;
const auto out_ports_size = m_editor->get_output_ports(m_node).size();
for (int out_idx = 0; out_idx < out_ports_size; ++out_idx) {
auto consuming_ops_out = get_output_port(out_idx)->get_consuming_ports();
consuming_ports.insert(consuming_ports.end(), consuming_ops_out.begin(), consuming_ops_out.end());
}
return consuming_ports;
}
namespace {
std::vector<Place::Ptr> get_consuming_ops(std::vector<Place::Ptr> input_ports) {
std::vector<Place::Ptr> consuming_ops;
std::transform(std::begin(input_ports),
std::end(input_ports),
std::back_inserter(consuming_ops),
[](const Place::Ptr place) {
return place->get_consuming_operations().at(0);
});
return consuming_ops;
}
} // namespace
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations() const {
std::vector<Place::Ptr> consuming_ports = get_consuming_ports();
return get_consuming_ops(consuming_ports);
}
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations(int output_port_index) const {
std::vector<Place::Ptr> consuming_ports = get_output_port(output_port_index)->get_consuming_ports();
return get_consuming_ops(consuming_ports);
}
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations(const std::string& output_port_name) const {
std::vector<Place::Ptr> consuming_ports = get_output_port(output_port_name)->get_consuming_ports();
return get_consuming_ops(consuming_ports);
}
Place::Ptr PlaceOpONNX::get_producing_operation() const {
const auto input_port = get_input_port();
if (input_port != nullptr) {
return input_port->get_producing_operation();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_producing_operation(int input_port_index) const {
const auto input_port = get_input_port(input_port_index);
if (input_port != nullptr) {
return input_port->get_producing_operation();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_producing_operation(const std::string& input_port_name) const {
const auto input_port = get_input_port(input_port_name);
if (input_port != nullptr) {
return input_port->get_producing_operation();
}
return nullptr;
}
bool PlaceOpONNX::is_equal(Place::Ptr another) const {
if (const auto place_op = std::dynamic_pointer_cast<PlaceOpONNX>(another)) {
const auto& another_node = place_op->get_editor_node();
if (m_editor->is_correct_and_unambiguous_node(m_node) ||
m_editor->is_correct_and_unambiguous_node(another_node)) {
return m_editor->get_node_index(m_node) == m_editor->get_node_index(another_node);
}
}
return false;
}
Place::Ptr PlaceOpONNX::get_target_tensor() const {
const auto output_port = get_output_port();
if (output_port != nullptr) {
return output_port->get_target_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_target_tensor(int output_port_index) const {
const auto output_port = get_output_port(output_port_index);
if (output_port != nullptr) {
return output_port->get_target_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_target_tensor(const std::string& output_name) const {
const auto output_port = get_output_port(output_name);
if (output_port != nullptr) {
return output_port->get_target_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_source_tensor() const {
const auto input_port = get_input_port();
if (input_port != nullptr) {
return input_port->get_source_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_source_tensor(int input_port_index) const {
const auto input_port = get_input_port(input_port_index);
if (input_port != nullptr) {
return input_port->get_source_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_source_tensor(const std::string& input_name) const {
const auto input_port = get_input_port(input_name);
if (input_port != nullptr) {
return input_port->get_source_tensor();
}
return nullptr;
}
bool PlaceOpONNX::is_input() const {
return false;
}
bool PlaceOpONNX::is_output() const {
return false;
}

View File

@ -7,6 +7,7 @@
#include <editor.hpp>
#include <frontend_manager/place.hpp>
#include <memory>
#include <sstream>
namespace ngraph {
namespace frontend {
@ -15,17 +16,18 @@ public:
PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
PlaceInputEdgeONNX(onnx_editor::InputEdge&& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
// internal usage
onnx_editor::InputEdge get_input_edge() const;
// external usage
bool is_input() const override;
bool is_output() const override;
bool is_equal(Place::Ptr another) const override;
bool is_equal_data(Place::Ptr another) const override;
Place::Ptr get_source_tensor() const override;
std::vector<Place::Ptr> get_consuming_operations() const override;
Place::Ptr get_producing_operation() const override;
Place::Ptr get_producing_port() const override;
private:
onnx_editor::InputEdge m_edge;
@ -37,17 +39,18 @@ public:
PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
PlaceOutputEdgeONNX(onnx_editor::OutputEdge&& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
// internal usage
onnx_editor::OutputEdge get_output_edge() const;
// external usage
bool is_input() const override;
bool is_output() const override;
bool is_equal(Place::Ptr another) const override;
bool is_equal_data(Place::Ptr another) const override;
Place::Ptr get_target_tensor() const override;
std::vector<Place::Ptr> get_consuming_ports() const override;
Place::Ptr get_producing_operation() const override;
std::vector<Place::Ptr> get_consuming_operations() const override;
private:
onnx_editor::OutputEdge m_edge;
@ -59,21 +62,16 @@ public:
PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
PlaceTensorONNX(std::string&& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
// external usage
std::vector<std::string> get_names() const override;
Place::Ptr get_producing_port() const override;
std::vector<Place::Ptr> get_consuming_ports() const override;
Ptr get_input_port(int input_port_index) const override;
Place::Ptr get_producing_operation() const override;
bool is_input() const override;
bool is_output() const override;
bool is_equal(Place::Ptr another) const override;
bool is_equal_data(Place::Ptr another) const override;
std::vector<Place::Ptr> get_consuming_operations() const override;
private:
std::string m_name;
@ -86,6 +84,10 @@ public:
PlaceOpONNX(onnx_editor::EditorNode&& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
std::vector<std::string> get_names() const override;
// internal usage
onnx_editor::EditorNode get_editor_node() const;
// external usage
Place::Ptr get_output_port() const override;
Place::Ptr get_output_port(int output_port_index) const override;
Place::Ptr get_output_port(const std::string& output_port_name) const override;
@ -94,6 +96,27 @@ public:
Place::Ptr get_input_port(int input_port_index) const override;
Place::Ptr get_input_port(const std::string& input_name) const override;
std::vector<Place::Ptr> get_consuming_ports() const override;
std::vector<Place::Ptr> get_consuming_operations() const override;
std::vector<Place::Ptr> get_consuming_operations(int output_port_index) const override;
std::vector<Place::Ptr> get_consuming_operations(const std::string& output_port_name) const override;
Place::Ptr get_producing_operation() const override;
Place::Ptr get_producing_operation(int input_port_index) const override;
Place::Ptr get_producing_operation(const std::string& input_port_name) const override;
Place::Place::Ptr get_target_tensor() const override;
Place::Ptr get_target_tensor(int output_port_index) const override;
Place::Ptr get_target_tensor(const std::string& output_name) const override;
Place::Place::Ptr get_source_tensor() const override;
Place::Ptr get_source_tensor(int input_port_index) const override;
Place::Ptr get_source_tensor(const std::string& input_name) const override;
bool is_equal(Place::Ptr another) const override;
bool is_input() const override;
bool is_output() const override;
private:
onnx_editor::EditorNode m_node;
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;

View File

@ -693,6 +693,26 @@ NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_ind
EXPECT_EQ(edge2.m_new_input_name, "custom_input_name_2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_index) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
const InputEdge edge = editor.find_input_edge(EditorNode{0}, EditorInput{0, "custom_input_name_1"});
EXPECT_EQ(edge.m_node_idx, 0);
EXPECT_EQ(edge.m_port_idx, 0);
EXPECT_EQ(edge.m_new_input_name, "custom_input_name_1");
const InputEdge edge2 = editor.find_input_edge(EditorNode{5}, EditorInput{0});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_port_idx, 0);
try {
editor.find_input_edge(EditorNode{99}, EditorInput{"conv1/7x7_s2_1"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
@ -767,6 +787,21 @@ NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_i
EXPECT_EQ(edge2.m_port_idx, 1);
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_index) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
const OutputEdge edge = editor.find_output_edge(EditorNode{5}, EditorOutput{1});
EXPECT_EQ(edge.m_node_idx, 5);
EXPECT_EQ(edge.m_port_idx, 1);
try {
editor.find_output_edge(EditorNode{99}, EditorOutput{"conv1/7x7_s2_1"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.onnx")};
@ -1044,6 +1079,12 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{2});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{99});
EXPECT_EQ(is_correct_node, false);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}});
EXPECT_EQ(is_correct_node, false);
@ -1054,6 +1095,32 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
EXPECT_EQ(is_correct_node, false);
}
NGRAPH_TEST(onnx_editor, editor_api_get_node_index) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
EXPECT_EQ(editor.get_node_index(EditorNode{2}), 2);
EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"relu1"}}), 0);
EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"split2"}}), 5);
EXPECT_EQ(editor.get_node_index(EditorNode{"relu1_name"}), 0);
try {
editor.get_node_index(EditorNode{99});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
}
try {
editor.get_node_index(EditorNode{"add_ambiguous_name"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(
msg.find(
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_input_edge_from_tensor_with_single_consumer) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_ab.onnx")};
@ -1365,15 +1432,18 @@ NGRAPH_TEST(onnx_editor, get_input_ports) {
editor.get_input_ports(EditorNode{"add_ambiguous_name"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("The node with name: add_ambiguous_name, output_name: not_given is ambiguous") !=
std::string::npos);
EXPECT_TRUE(
msg.find(
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
std::string::npos);
}
try {
editor.get_input_ports(EditorNode{""});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("The node with name: not_given, output_name: not_given is ambiguous") !=
std::string::npos);
EXPECT_TRUE(
msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, get_output_ports) {
@ -1393,14 +1463,17 @@ NGRAPH_TEST(onnx_editor, get_output_ports) {
editor.get_output_ports(EditorNode{"add_ambiguous_name"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("The node with name: add_ambiguous_name, output_name: not_given is ambiguous") !=
std::string::npos);
EXPECT_TRUE(
msg.find(
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
std::string::npos);
}
try {
editor.get_output_ports(EditorNode{""});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("The node with name: not_given, output_name: not_given is ambiguous") !=
std::string::npos);
EXPECT_TRUE(
msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") !=
std::string::npos);
}
}

View File

@ -89,13 +89,23 @@ void regclass_pyngraph_Place(py::module m) {
place.def(
"get_consuming_operations",
[](const ngraph::frontend::Place& self, py::object outputPortIndex) {
if (outputPortIndex == py::none()) {
return self.get_consuming_operations();
[](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) {
if (outputName == py::none()) {
if (outputPortIndex == py::none()) {
return self.get_consuming_operations();
} else {
return self.get_consuming_operations(py::cast<int>(outputPortIndex));
}
} else {
return self.get_consuming_operations(py::cast<int>(outputPortIndex));
if (outputPortIndex == py::none()) {
return self.get_consuming_operations(py::cast<std::string>(outputName));
} else {
return self.get_consuming_operations(py::cast<std::string>(outputName),
py::cast<int>(outputPortIndex));
}
}
},
py::arg("outputName") = py::none(),
py::arg("outputPortIndex") = py::none(),
R"(
Returns references to all operation nodes that consume data from this place for specified output port.
@ -103,6 +113,8 @@ void regclass_pyngraph_Place(py::module m) {
Parameters
----------
outputName : str
Name of output port group. May not be set if node has one output port group.
outputPortIndex : int
If place is an operational node it specifies which output port should be considered
May not be set if node has only one output port.
@ -115,13 +127,22 @@ void regclass_pyngraph_Place(py::module m) {
place.def(
"get_target_tensor",
[](const ngraph::frontend::Place& self, py::object outputPortIndex) {
if (outputPortIndex == py::none()) {
return self.get_target_tensor();
[](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) {
if (outputName == py::none()) {
if (outputPortIndex == py::none()) {
return self.get_target_tensor();
} else {
return self.get_target_tensor(py::cast<int>(outputPortIndex));
}
} else {
return self.get_target_tensor(py::cast<int>(outputPortIndex));
if (outputPortIndex == py::none()) {
return self.get_target_tensor(py::cast<std::string>(outputName));
} else {
return self.get_target_tensor(py::cast<std::string>(outputName), py::cast<int>(outputPortIndex));
}
}
},
py::arg("outputName") = py::none(),
py::arg("outputPortIndex") = py::none(),
R"(
Returns a tensor place that gets data from this place; applicable for operations,
@ -129,6 +150,8 @@ void regclass_pyngraph_Place(py::module m) {
Parameters
----------
outputName : str
Name of output port group. May not be set if node has one output port group.
outputPortIndex : int
Output port index if the current place is an operation node and has multiple output ports.
May not be set if place has only one output port.
@ -141,19 +164,31 @@ void regclass_pyngraph_Place(py::module m) {
place.def(
"get_producing_operation",
[](const ngraph::frontend::Place& self, py::object inputPortIndex) {
if (inputPortIndex == py::none()) {
return self.get_producing_operation();
[](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) {
if (inputName == py::none()) {
if (inputPortIndex == py::none()) {
return self.get_producing_operation();
} else {
return self.get_producing_operation(py::cast<int>(inputPortIndex));
}
} else {
return self.get_producing_operation(py::cast<int>(inputPortIndex));
if (inputPortIndex == py::none()) {
return self.get_producing_operation(py::cast<std::string>(inputName));
} else {
return self.get_producing_operation(py::cast<std::string>(inputName),
py::cast<int>(inputPortIndex));
}
}
},
py::arg("inputName") = py::none(),
py::arg("inputPortIndex") = py::none(),
R"(
Get an operation node place that immediately produces data for this place.
Parameters
----------
inputName : str
Name of port group. May not be set if node has one input port group.
inputPortIndex : int
If a given place is itself an operation node, this specifies a port index.
May not be set if place has only one input port.
@ -260,13 +295,22 @@ void regclass_pyngraph_Place(py::module m) {
place.def(
"get_source_tensor",
[](const ngraph::frontend::Place& self, py::object inputPortIndex) {
if (inputPortIndex == py::none()) {
return self.get_source_tensor();
[](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) {
if (inputName == py::none()) {
if (inputPortIndex == py::none()) {
return self.get_source_tensor();
} else {
return self.get_source_tensor(py::cast<int>(inputPortIndex));
}
} else {
return self.get_source_tensor(py::cast<int>(inputPortIndex));
if (inputPortIndex == py::none()) {
return self.get_source_tensor(py::cast<std::string>(inputName));
} else {
return self.get_source_tensor(py::cast<std::string>(inputName), py::cast<int>(inputPortIndex));
}
}
},
py::arg("inputName") = py::none(),
py::arg("inputPortIndex") = py::none(),
R"(
Returns a tensor place that supplies data for this place; applicable for operations,
@ -274,6 +318,8 @@ void regclass_pyngraph_Place(py::module m) {
Parameters
----------
inputName : str
Name of port group. May not be set if node has one input port group.
inputPortIndex : int
Input port index for operational node. May not be specified if place has only one input port.

View File

@ -107,12 +107,29 @@ public:
std::vector<Place::Ptr> get_consuming_operations() const override {
m_stat.m_get_consuming_operations++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = "";
return {std::make_shared<PlaceMockPy>()};
}
std::vector<Place::Ptr> get_consuming_operations(int outputPortIndex) const override {
m_stat.m_get_consuming_operations++;
m_stat.m_lastArgInt = outputPortIndex;
m_stat.m_lastArgString = "";
return {std::make_shared<PlaceMockPy>()};
}
std::vector<Place::Ptr> get_consuming_operations(const std::string& outputName) const override {
m_stat.m_get_consuming_operations++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = outputName;
return {std::make_shared<PlaceMockPy>()};
}
std::vector<Place::Ptr> get_consuming_operations(const std::string& outputName,
int outputPortIndex) const override {
m_stat.m_get_consuming_operations++;
m_stat.m_lastArgInt = outputPortIndex;
m_stat.m_lastArgString = outputName;
return {std::make_shared<PlaceMockPy>()};
}
@ -128,6 +145,20 @@ public:
return std::make_shared<PlaceMockPy>();
}
Place::Ptr get_target_tensor(const std::string& outputName) const override {
m_stat.m_get_target_tensor++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = outputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const override {
m_stat.m_get_target_tensor++;
m_stat.m_lastArgInt = outputPortIndex;
m_stat.m_lastArgString = outputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_producing_operation() const override {
m_stat.m_get_producing_operation++;
m_stat.m_lastArgInt = -1;
@ -140,6 +171,20 @@ public:
return std::make_shared<PlaceMockPy>();
}
Place::Ptr get_producing_operation(const std::string& inputName) const override {
m_stat.m_get_producing_operation++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_producing_operation(const std::string& inputName, int inputPortIndex) const override {
m_stat.m_get_producing_operation++;
m_stat.m_lastArgInt = inputPortIndex;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_producing_port() const override {
m_stat.m_get_producing_port++;
return std::make_shared<PlaceMockPy>();
@ -236,6 +281,20 @@ public:
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_source_tensor(const std::string& inputName) const override {
m_stat.m_get_source_tensor++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_source_tensor(const std::string& inputName, int inputPortIndex) const override {
m_stat.m_get_source_tensor++;
m_stat.m_lastArgInt = inputPortIndex;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
//---------------Stat--------------------
PlaceStat get_stat() const {
return m_stat;

View File

@ -9,6 +9,7 @@ from ngraph import PartialShape
from ngraph.frontend import FrontEndManager
# ------Test input model 1------
# in1 in2 in3
# | | |
# \ / |
@ -24,9 +25,32 @@ from ngraph.frontend import FrontEndManager
# / \ |
# out1 out2 out4
#
#
# ------Test input model 2------
# in1 in2
# | |
# \ /
# +--------+
# | Add |
# +--------+
# <add_out>
# |
# +--------+
# | Split |
# |(split2)|
# +--------+
# / \
# <sp_out1> <sp_out2>
# +-------+ +-------+
# | Abs | | Sin |
# | (abs1)| | |
# +------ + +-------+
# | |
# out1 out2
#
def create_test_onnx_models():
models = {}
# Input model
# Input model 1
add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"])
split = onnx.helper.make_node("Split", inputs=["add_out"],
outputs=["out1", "out2"], name="split1", axis=0)
@ -48,6 +72,24 @@ def create_test_onnx_models():
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Input model 2
split_2 = onnx.helper.make_node("Split", inputs=["add_out"],
outputs=["sp_out1", "sp_out2"], name="split2", axis=0)
abs = onnx.helper.make_node("Abs", inputs=["sp_out1"], outputs=["out1"], name="abs1")
sin = onnx.helper.make_node("Sin", inputs=["sp_out2"], outputs=["out2"])
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
]
graph = make_graph([add, split_2, abs, sin], "test_graph_2", input_tensors, output_tensors)
models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
@ -312,8 +354,9 @@ def test_extract_subgraph_4():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place4 = model.get_place_by_tensor_name(tensorName="out1")
place5 = model.get_place_by_tensor_name(tensorName="out2")
@ -377,8 +420,9 @@ def test_override_all_inputs():
place1 = model.get_place_by_operation_name_and_input_port(
operationName="split1", inputPortIndex=0)
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place4 = model.get_place_by_tensor_name(tensorName="in3")
model.override_all_inputs(inputs=[place1, place2, place3, place4])
result_func = fe.convert(model)
@ -432,11 +476,15 @@ def test_is_input_output():
assert not place3.is_input()
assert not place3.is_output()
place4 = place1 = model.get_place_by_operation_name_and_input_port(
place4 = model.get_place_by_operation_name_and_input_port(
operationName="split1", inputPortIndex=0)
assert not place4.is_input()
assert not place4.is_output()
place5 = model.get_place_by_operation_name(operationName="split1")
assert not place5.is_input()
assert not place5.is_output()
def test_set_partial_shape():
skip_if_onnx_frontend_is_disabled()
@ -520,12 +568,14 @@ def test_is_equal():
place2 = model.get_place_by_tensor_name(tensorName="out2")
assert place2.is_equal(place2)
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
place4 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert place3.is_equal(place4)
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place6 = model.get_place_by_tensor_name(tensorName="out1").get_input_port(inputPortIndex=0)
place6 = out1_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert place5.is_equal(place6)
place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port()
@ -538,6 +588,10 @@ def test_is_equal():
assert not place6.is_equal(place7)
assert not place8.is_equal(place2)
place9 = model.get_place_by_operation_name(operationName="split1")
assert place2.get_producing_operation().is_equal(place9)
assert not place9.is_equal(place2)
def test_is_equal_data():
skip_if_onnx_frontend_is_disabled()
@ -560,11 +614,12 @@ def test_is_equal_data():
place4 = place2.get_producing_port()
assert place2.is_equal_data(place4)
place5 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place5 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert place2.is_equal_data(place5)
assert place4.is_equal_data(place5)
place6 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
place6 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
assert place6.is_equal_data(place5)
place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
@ -648,3 +703,198 @@ def test_get_input_port():
assert not split_op.get_input_port(inputPortIndex=1)
assert not split_op.get_input_port(inputName="not_existed")
def test_get_consuming_ports():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
add_tensor_consuming_ports = place1.get_consuming_ports()
assert len(add_tensor_consuming_ports) == 3
place2 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
assert add_tensor_consuming_ports[0].is_equal(place2)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert add_tensor_consuming_ports[1].is_equal(place3)
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
assert add_tensor_consuming_ports[2].is_equal(place4)
add_op_consuming_ports = place1.get_producing_operation().get_consuming_ports()
assert len(add_op_consuming_ports) == len(add_tensor_consuming_ports)
for i in range(len(add_op_consuming_ports)):
assert add_op_consuming_ports[i].is_equal(add_tensor_consuming_ports[i])
def test_get_consuming_ports_2():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op_consuming_ports = split_op.get_consuming_ports()
assert len(split_op_consuming_ports) == 2
abs_input_port = model.get_place_by_operation_name(operationName="abs1").get_input_port(inputPortIndex=0)
assert split_op_consuming_ports[0].is_equal(abs_input_port)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
sin_input_port = out2_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert split_op_consuming_ports[1].is_equal(sin_input_port)
split_out_port_0 = split_op.get_output_port(outputPortIndex=0)
split_out_port_0_consuming_ports = split_out_port_0.get_consuming_ports()
assert len(split_out_port_0_consuming_ports) == 1
assert split_out_port_0_consuming_ports[0].is_equal(abs_input_port)
def test_get_producing_operation():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_tensor_out_2 = model.get_place_by_tensor_name(tensorName="sp_out2")
split_op = model.get_place_by_operation_name(operationName="split2")
assert split_tensor_out_2.get_producing_operation().is_equal(split_op)
split_op = model.get_place_by_operation_name(operationName="split2")
split_out_port_2 = split_op.get_output_port(outputPortIndex=1)
assert split_out_port_2.get_producing_operation().is_equal(split_op)
def test_get_producing_operation_2():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
abs_op = model.get_place_by_operation_name(operationName="abs1")
abs_port_0 = abs_op.get_input_port()
split_op = model.get_place_by_operation_name(operationName="split2")
assert abs_port_0.get_producing_operation().is_equal(split_op)
assert abs_op.get_producing_operation().is_equal(split_op)
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_op = add_out_tensor.get_producing_operation()
assert not add_op.get_producing_operation()
split_op_producing_op = split_op.get_producing_operation(inputName="add_out")
assert split_op_producing_op.is_equal(add_op)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
sin_op = out2_tensor.get_producing_operation()
assert sin_op.get_producing_operation(inputPortIndex=0).is_equal(split_op)
def test_get_consuming_operations():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op_consuming_ops = split_op.get_consuming_operations()
abs_op = model.get_place_by_operation_name(operationName="abs1")
sin_op = model.get_place_by_tensor_name(tensorName="out2").get_producing_operation()
assert len(split_op_consuming_ops) == 2
assert split_op_consuming_ops[0].is_equal(abs_op)
assert split_op_consuming_ops[1].is_equal(sin_op)
split_op_port = split_op.get_input_port(inputPortIndex=0)
split_op_port_consuming_ops = split_op_port.get_consuming_operations()
assert len(split_op_port_consuming_ops) == 1
assert split_op_port_consuming_ops[0].is_equal(split_op)
add_out_port = model.get_place_by_tensor_name(tensorName="add_out").get_producing_port()
add_out_port_consuming_ops = add_out_port.get_consuming_operations()
assert len(add_out_port_consuming_ops) == 1
assert add_out_port_consuming_ops[0].is_equal(split_op)
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
sp_out2_tensor_consuming_ops = sp_out2_tensor.get_consuming_operations()
assert len(sp_out2_tensor_consuming_ops) == 1
assert sp_out2_tensor_consuming_ops[0].is_equal(sin_op)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
out2_tensor_consuming_ops = out2_tensor.get_consuming_operations()
assert len(out2_tensor_consuming_ops) == 0
out2_port_consuming_ops = out2_tensor.get_producing_port().get_consuming_operations()
assert len(out2_port_consuming_ops) == 0
split_out_1_consuming_ops = split_op.get_consuming_operations(outputPortIndex=1)
assert len(split_out_1_consuming_ops) == 1
split_out_sp_out_2_consuming_ops = split_op.get_consuming_operations(outputName="sp_out2")
assert len(split_out_sp_out_2_consuming_ops) == 1
assert split_out_1_consuming_ops[0].is_equal(split_out_sp_out_2_consuming_ops[0])
assert split_out_1_consuming_ops[0].is_equal(sin_op)
def test_get_target_tensor():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
assert not split_op.get_target_tensor()
split_op_tensor_1 = split_op.get_target_tensor(outputPortIndex=1)
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
assert split_op_tensor_1.is_equal(sp_out2_tensor)
split_tensor_sp_out2 = split_op.get_target_tensor(outputName="sp_out2")
assert split_tensor_sp_out2.is_equal(split_op_tensor_1)
abs_op = model.get_place_by_operation_name(operationName="abs1")
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
assert abs_op.get_target_tensor().is_equal(out1_tensor)
def test_get_source_tensor():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_op = add_out_tensor.get_producing_operation()
assert not add_op.get_source_tensor()
add_op_in_tensor_1 = add_op.get_source_tensor(inputPortIndex=1)
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
assert add_op_in_tensor_1.is_equal(in2_tensor)
add_op_in_tensor_in2 = add_op.get_source_tensor(inputName="in2")
assert add_op_in_tensor_in2.is_equal(in2_tensor)
split_op = model.get_place_by_operation_name(operationName="split2")
assert split_op.get_source_tensor().is_equal(add_out_tensor)
def test_get_producing_port():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op_in_port = split_op.get_input_port()
split_op_in_port_prod_port = split_op_in_port.get_producing_port()
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_op = add_out_tensor.get_producing_operation()
add_op_out_port = add_op.get_output_port()
assert split_op_in_port_prod_port.is_equal(add_op_out_port)

View File

@ -440,6 +440,16 @@ def test_place_get_consuming_operations():
stat = get_place_stat(place)
assert stat.get_consuming_operations == 2
assert stat.lastArgInt == -1
assert place.get_consuming_operations(outputName="2") is not None
stat = get_place_stat(place)
assert stat.get_consuming_operations == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_consuming_operations(outputName="3", outputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_consuming_operations == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"
@mock_needed
@ -453,6 +463,16 @@ def test_place_get_target_tensor():
stat = get_place_stat(place)
assert stat.get_target_tensor == 2
assert stat.lastArgInt == -1
assert place.get_target_tensor(outputName="2") is not None
stat = get_place_stat(place)
assert stat.get_target_tensor == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_target_tensor(outputName="3", outputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_target_tensor == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"
@mock_needed
@ -466,6 +486,16 @@ def test_place_get_producing_operation():
stat = get_place_stat(place)
assert stat.get_producing_operation == 2
assert stat.lastArgInt == -1
assert place.get_producing_operation(inputName="2") is not None
stat = get_place_stat(place)
assert stat.get_producing_operation == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_producing_operation(inputName="3", inputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_producing_operation == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"
@mock_needed
@ -551,3 +581,13 @@ def test_place_get_source_tensor():
stat = get_place_stat(place)
assert stat.get_source_tensor == 2
assert stat.lastArgInt == 22
assert place.get_source_tensor(inputName="2") is not None
stat = get_place_stat(place)
assert stat.get_source_tensor == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_source_tensor(inputName="3", inputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_source_tensor == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"