Implemented missing methods of ONNX Place API (#7518)

* Implemented get_consuming_ports, get_producing_operation and is_equal for PlaceOpONNX

* fixed unambiguous_node_check

* removed PlaceTensorONNX::get_input_port

* added PlaceOpONNX::is_input, PlaceOpONNX::is_output

* fixed python styles

* added get_consuming_operations implementation

* added missing get_consuming_operations for PlaceOpONNX

* added missing get_target_tensor for PlaceOpONNX

* changed place spec

* add support of get_source_tensor

* add support of get_producing_operation for PlaceOpONNX

* add support of get_producing_port for PlaceInputEdgeONNX

* fixed python styles

* missing ref in std::transform
Mateusz Bencer, 2021-09-23 08:47:33 +02:00, committed by GitHub
parent d7dfce2091
commit fb11560b82
13 changed files with 864 additions and 133 deletions
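
Taken together, the new Place methods let a client walk an ONNX graph through the generic frontend Place interface. A minimal usage sketch, not taken from this commit: the tensor name "out1" and the already-loaded InputModel are assumptions, get_place_by_tensor_name is pre-existing API, and only the navigation calls are what this change adds.

#include <frontend_manager/frontend_manager.hpp>

using namespace ngraph::frontend;

void inspect_places(const InputModel::Ptr& model) {
    // "out1" is a hypothetical tensor name in the loaded model
    Place::Ptr tensor = model->get_place_by_tensor_name("out1");

    Place::Ptr producer = tensor->get_producing_operation();  // operation that writes the tensor
    auto consumers = tensor->get_consuming_operations();      // operations that read the tensor
    Place::Ptr source = producer->get_source_tensor(0);       // tensor feeding the producer's input port 0

    // operation places reached by different routes now compare equal
    bool same = producer->is_equal(tensor->get_producing_port()->get_producing_operation());
    (void)consumers;
    (void)source;
    (void)same;
}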


@@ -103,7 +103,7 @@ public:
    ///
    /// \return A vector with all operation node references that consumes data from this
    /// place
    virtual std::vector<Ptr> get_consuming_operations(const std::string& outputName) const;

    /// \brief Returns references to all operation nodes that consume data from this place
    /// for specified output port
@@ -127,16 +127,14 @@ public:
    /// \return A tensor place which hold the resulting value for this place
    virtual Ptr get_target_tensor() const;

    /// \brief Returns a tensor place that gets data from this place; applicable for operations
    ///
    /// \param outputName Name of output port group
    ///
    /// \return A tensor place which hold the resulting value for this place
    virtual Ptr get_target_tensor(const std::string& outputName) const;

    /// \brief Returns a tensor place that gets data from this place; applicable for operations
    ///
    /// \param outputName Name of output port group, each group can have multiple ports
    ///
@@ -146,8 +144,7 @@ public:
    /// \return A tensor place which hold the resulting value for this place
    virtual Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const;

    /// \brief Returns a tensor place that gets data from this place; applicable for operations
    ///
    /// \param output_port_index Output port index if the current place is an operation node
    /// and has multiple output ports
@@ -161,24 +158,21 @@ public:
    /// \return A tensor place which supplies data for this place
    virtual Ptr get_source_tensor() const;

    /// \brief Returns a tensor place that supplies data for this place; applicable for operations
    ///
    /// \param input_port_index Input port index for operational nodes.
    ///
    /// \return A tensor place which supplies data for this place
    virtual Ptr get_source_tensor(int input_port_index) const;

    /// \brief Returns a tensor place that supplies data for this place; applicable for operations
    ///
    /// \param inputName Name of input port group
    ///
    /// \return A tensor place which supplies data for this place
    virtual Ptr get_source_tensor(const std::string& inputName) const;

    /// \brief Returns a tensor place that supplies data for this place; applicable for operations
    ///
    /// \param inputName If a given place is itself an operation node, this specifies name
    /// of output port group, each group can have multiple ports
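
A compact sketch of how the overloads documented above are meant to be combined on an operation place; the op handle and the "out"/"data" port group names are illustrative assumptions, while the method signatures are the ones declared here.

#include <frontend_manager/place.hpp>

void overload_examples(const ngraph::frontend::Place::Ptr& op) {
    auto t_default = op->get_target_tensor();          // only valid when the op has a single output port
    auto t_by_index = op->get_target_tensor(1);        // by output port index
    auto t_by_name = op->get_target_tensor("out");     // by output port group name
    auto t_by_both = op->get_target_tensor("out", 0);  // by group name plus index inside the group

    auto s_default = op->get_source_tensor();          // only valid when the op has a single input port
    auto s_by_index = op->get_source_tensor(0);
    auto s_by_name = op->get_source_tensor("data");
    (void)t_default; (void)t_by_index; (void)t_by_name; (void)t_by_both;
    (void)s_default; (void)s_by_index; (void)s_by_name;
}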


@@ -95,37 +95,41 @@ std::vector<int> onnx_editor::EdgeMapper::get_node_input_indexes(int node_index,
}

InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const EditorInput& in) const {
    int node_index = node.m_node_index;
    if (node_index == -1) {  // the node index is not provided
        // identification can be both based on node name and output name (if the node index is not provided)
        const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
        if (node_indexes.size() == 1) {
            node_index = node_indexes[0];
        } else if (node_indexes.empty()) {
            throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
                               " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
                               " was not found");
        } else if (!in.m_input_name.empty())  // input indexes are not deterministic if a node name is ambiguous
        {
            // many nodes with the same name
            // check if some of found index matches input name
            int matched_inputs_number = 0;
            for (const auto& index : node_indexes) {
                if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) {
                    node_index = index;
                    ++matched_inputs_number;
                }
            }
            if (matched_inputs_number == 0) {
                throw ngraph_error("Input edge described by: " + node.m_node_name +
                                   " and input name: " + in.m_input_name + " was not found");
            }
            if (matched_inputs_number > 1) {
                throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name +
                                   " are ambiguous to determine input edge");
            }
        } else {
            throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " +
                               std::to_string(in.m_input_index) + " are ambiguous to determine input edge");
        }
    } else {  // the node index is provided
        check_node_index(node_index);
    }
    if (in.m_input_index != -1)  // input index is set
    {
@@ -146,33 +150,38 @@ InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const
}
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node, const EditorOutput& out) const {
    int node_index = node.m_node_index;
    if (node_index == -1) {  // the node index is not provided
        // identification can be both based on node name and output name (if the node index is not provided)
        const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
        if (node_indexes.size() == 1) {
            node_index = node_indexes[0];
        } else if (node_indexes.empty()) {
            throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
                               " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
                               " was not found");
        } else if (!out.m_output_name.empty())  // output indexes are not deterministic if a node name is ambiguous
        {
            // many nodes with the same name
            // check if some of found index matches output name
            int matched_outputs_number = 0;
            for (const auto& index : node_indexes) {
                if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) >
                    0) {
                    node_index = index;
                    ++matched_outputs_number;
                }
            }
            if (matched_outputs_number == 0) {
                throw ngraph_error("Output edge described by: " + node.m_node_name +
                                   " and output name: " + out.m_output_name + " was not found");
            }
        } else {
            throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " +
                               std::to_string(out.m_output_index) + " are ambiguous to determine output edge");
        }
    } else {  // the node index is provided
        check_node_index(node_index);
    }
    if (out.m_output_index != -1)  // output index is set
    {
@@ -197,16 +206,45 @@ std::vector<InputEdge> onnx_editor::EdgeMapper::find_output_consumers(const std:
        const auto node_idx = it->second;
        const auto port_indexes = get_node_input_indexes(node_idx, output_name);
        for (const auto& idx : port_indexes) {
            const auto consumer_edge = InputEdge{node_idx, idx};
            if (std::find_if(std::begin(input_edges), std::end(input_edges), [&consumer_edge](const InputEdge& edge) {
                    return edge.m_node_idx == consumer_edge.m_node_idx && edge.m_port_idx == consumer_edge.m_port_idx;
                }) == std::end(input_edges)) {
                // only unique
                input_edges.push_back(consumer_edge);
            }
        }
    }
    return input_edges;
}

bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const {
    if (node.m_node_index >= 0 && node.m_node_index < static_cast<int>(m_node_inputs.size())) {
        return true;
    }
    return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
}

namespace {
void check_node(bool condition, const EditorNode& node) {
    NGRAPH_CHECK(condition,
                 "The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
                     ", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
                     ", node_index: " + (node.m_node_index == -1 ? "not_given" : std::to_string(node.m_node_index)) +
                     " is ambiguous");
}
}  // namespace

int onnx_editor::EdgeMapper::get_node_index(const EditorNode& node) const {
    if (node.m_node_index != -1) {  // the node index provided
        check_node_index(node.m_node_index);
        return node.m_node_index;
    }
    const auto indexes = find_node_indexes(node.m_node_name, node.m_output_name);
    check_node(indexes.size() == 1, node);
    return indexes[0];
}
bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) const {
    if (m_node_output_name_to_index.find(name) != std::end(m_node_output_name_to_index)) {
        return true;
@@ -218,20 +256,25 @@ bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) co
}

std::vector<std::string> onnx_editor::EdgeMapper::get_input_ports(const EditorNode& node) const {
    check_node(is_correct_and_unambiguous_node(node), node);
    auto node_index = node.m_node_index;
    if (node_index == -1) {  // the node index is not provided
        node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
    } else {
        check_node_index(node_index);
    }
    return m_node_inputs[node_index];
}

std::vector<std::string> onnx_editor::EdgeMapper::get_output_ports(const EditorNode& node) const {
    check_node(is_correct_and_unambiguous_node(node), node);
    auto node_index = node.m_node_index;
    if (node_index == -1) {  // the node index is not provided
        node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
    } else {
        check_node_index(node_index);
    }
    return m_node_outputs[node_index];
}

@@ -250,3 +293,8 @@ std::string onnx_editor::EdgeMapper::get_target_tensor_name(const OutputEdge& ed
    }
    return "";
}

void onnx_editor::EdgeMapper::check_node_index(int node_index) const {
    NGRAPH_CHECK(node_index >= 0 && node_index < static_cast<int>(m_node_inputs.size()),
                 "Provided node index: " + std::to_string(node_index) + " is out of scope");
}
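
check_node_index() and get_node_index() give index-based node selection the same validation path as name-based selection. A short sketch of the editor-level behaviour, mirroring the C++ tests added further down; the index values and the "relu1" output name come from the subgraph_extraction test model used there.

#include <editor.hpp>

void node_index_examples(ngraph::onnx_editor::ONNXModelEditor& editor) {
    using namespace ngraph::onnx_editor;
    const int by_position = editor.get_node_index(EditorNode{2});                    // third node in topological order
    const int by_output = editor.get_node_index(EditorNode{EditorOutput{"relu1"}});  // resolved via an output name
    // an out-of-range index throws "Provided node index: ... is out of scope",
    // an unknown or ambiguous name throws "... is ambiguous"
    (void)by_position;
    (void)by_output;
}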


@@ -92,6 +92,16 @@ public:
    ///
    bool is_correct_and_unambiguous_node(const EditorNode& node) const;

    /// \brief Returns index (position) of provided node in the graph
    /// in topological order.
    ///
    /// \param node An EditorNode helper structure created based on a node name
    /// or a node output name.
    ///
    /// \note The exception will be thrown if the provided node is ambiguous.
    ///
    int get_node_index(const EditorNode& node) const;

    /// \brief Returns true if a provided tensor name is correct (exists in a graph).
    ///
    /// \param name The name of tensor in a graph.
@@ -130,6 +140,7 @@ private:
    // note: a single node can have more than one inputs with the same name
    std::vector<int> get_node_input_indexes(int node_index, const std::string& input_name) const;
    int get_node_output_idx(int node_index, const std::string& output_name) const;
    void check_node_index(int node_index) const;

    std::vector<std::vector<std::string>> m_node_inputs;
    std::vector<std::vector<std::string>> m_node_outputs;


@@ -444,6 +444,11 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN
    return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
}

int onnx_editor::ONNXModelEditor::get_node_index(const EditorNode& node) const {
    update_mapper_if_needed();
    return m_pimpl->m_edge_mapper.get_node_index(node);
}

bool onnx_editor::ONNXModelEditor::is_correct_tensor_name(const std::string& name) const {
    update_mapper_if_needed();
    return m_pimpl->m_edge_mapper.is_correct_tensor_name(name);


@@ -198,6 +198,16 @@ public:
    ///
    bool is_correct_and_unambiguous_node(const EditorNode& node) const;

    /// \brief Returns index (position) of provided node in the graph
    /// in topological order.
    ///
    /// \param node An EditorNode helper structure created based on a node name
    /// or a node output name.
    ///
    /// \note The exception will be thrown if the provided node is ambiguous.
    ///
    int get_node_index(const EditorNode& node) const;

    /// \brief Returns true if a provided tensor name is correct (exists in a graph).
    ///
    /// \param name The name of tensor in a graph.


@@ -114,11 +114,14 @@ struct EditorOutput {
/// You can indicate test_node by name as EditorNode("test_node")
/// or by assigned output as EditorNode(EditorOutput("out1"))
/// or EditorNode(EditorOutput("out2"))
/// or you can determine the node by position of a node in an ONNX graph (in topological order).
struct EditorNode {
    EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {}
    EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {}
    EditorNode(const int node_index) : m_node_index{node_index} {}
    const std::string m_node_name = "";
    const std::string m_output_name = "";
    const int m_node_index = -1;
};
}  // namespace onnx_editor
}  // namespace ngraph
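
With the added constructor an EditorNode can be built three ways. A small sketch combining the index-based form with the existing edge look-ups; the node name, output name and indices are assumptions that mirror the test models used later in this change.

#include <editor.hpp>

void editor_node_examples(ngraph::onnx_editor::ONNXModelEditor& editor) {
    using namespace ngraph::onnx_editor;
    const EditorNode by_name{"split1"};                // by node name
    const EditorNode by_output{EditorOutput{"out1"}};  // by one of the node's output names
    const EditorNode by_index{5};                      // by position in topological order

    const InputEdge in_edge = editor.find_input_edge(by_index, EditorInput{0});
    const OutputEdge out_edge = editor.find_output_edge(by_index, EditorOutput{1});
    (void)by_name;
    (void)by_output;
    (void)in_edge;
    (void)out_edge;
}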


@@ -47,6 +47,18 @@ Place::Ptr PlaceInputEdgeONNX::get_source_tensor() const {
    return std::make_shared<PlaceTensorONNX>(m_editor->get_source_tensor_name(m_edge), m_editor);
}
std::vector<Place::Ptr> PlaceInputEdgeONNX::get_consuming_operations() const {
return {std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor)};
}
Place::Ptr PlaceInputEdgeONNX::get_producing_operation() const {
return get_source_tensor()->get_producing_operation();
}
Place::Ptr PlaceInputEdgeONNX::get_producing_port() const {
return get_source_tensor()->get_producing_port();
}
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
                                         std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
    : m_edge{edge},
@@ -85,6 +97,18 @@ Place::Ptr PlaceOutputEdgeONNX::get_target_tensor() const {
    return std::make_shared<PlaceTensorONNX>(m_editor->get_target_tensor_name(m_edge), m_editor);
}
std::vector<Place::Ptr> PlaceOutputEdgeONNX::get_consuming_ports() const {
return get_target_tensor()->get_consuming_ports();
}
Place::Ptr PlaceOutputEdgeONNX::get_producing_operation() const {
return std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor);
}
std::vector<Place::Ptr> PlaceOutputEdgeONNX::get_consuming_operations() const {
return get_target_tensor()->get_consuming_operations();
}
PlaceTensorONNX::PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
    : m_name{name},
      m_editor{std::move(editor)} {}
@@ -112,10 +136,8 @@ std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_ports() const {
    return ret;
}

Place::Ptr PlaceTensorONNX::get_producing_operation() const {
    return get_producing_port()->get_producing_operation();
}

bool PlaceTensorONNX::is_input() const {
@@ -146,6 +168,19 @@ bool PlaceTensorONNX::is_equal_data(Place::Ptr another) const {
           eq_to_consuming_port(another);
}
std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_operations() const {
std::vector<Place::Ptr> consuming_ports = get_consuming_ports();
std::vector<Place::Ptr> consuming_ops;
std::transform(std::begin(consuming_ports),
std::end(consuming_ports),
std::back_inserter(consuming_ops),
[](const Place::Ptr& place) {
return place->get_consuming_operations().at(0);
});
return consuming_ops;
}
PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
    : m_node{node},
      m_editor{std::move(editor)} {}
@@ -158,6 +193,10 @@ std::vector<std::string> PlaceOpONNX::get_names() const {
    return {m_node.m_node_name};
}
onnx_editor::EditorNode PlaceOpONNX::get_editor_node() const {
return m_node;
}
Place::Ptr PlaceOpONNX::get_output_port() const {
    if (m_editor->get_output_ports(m_node).size() == 1) {
        return get_output_port(0);
@@ -209,3 +248,133 @@ Place::Ptr PlaceOpONNX::get_input_port(const std::string& input_name) const {
    }
    return nullptr;
}
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_ports() const {
std::vector<Place::Ptr> consuming_ports;
const auto out_ports_size = m_editor->get_output_ports(m_node).size();
for (int out_idx = 0; out_idx < out_ports_size; ++out_idx) {
auto consuming_ops_out = get_output_port(out_idx)->get_consuming_ports();
consuming_ports.insert(consuming_ports.end(), consuming_ops_out.begin(), consuming_ops_out.end());
}
return consuming_ports;
}
namespace {
std::vector<Place::Ptr> get_consuming_ops(std::vector<Place::Ptr> input_ports) {
std::vector<Place::Ptr> consuming_ops;
std::transform(std::begin(input_ports),
std::end(input_ports),
std::back_inserter(consuming_ops),
[](const Place::Ptr place) {
return place->get_consuming_operations().at(0);
});
return consuming_ops;
}
} // namespace
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations() const {
std::vector<Place::Ptr> consuming_ports = get_consuming_ports();
return get_consuming_ops(consuming_ports);
}
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations(int output_port_index) const {
std::vector<Place::Ptr> consuming_ports = get_output_port(output_port_index)->get_consuming_ports();
return get_consuming_ops(consuming_ports);
}
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations(const std::string& output_port_name) const {
std::vector<Place::Ptr> consuming_ports = get_output_port(output_port_name)->get_consuming_ports();
return get_consuming_ops(consuming_ports);
}
Place::Ptr PlaceOpONNX::get_producing_operation() const {
const auto input_port = get_input_port();
if (input_port != nullptr) {
return input_port->get_producing_operation();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_producing_operation(int input_port_index) const {
const auto input_port = get_input_port(input_port_index);
if (input_port != nullptr) {
return input_port->get_producing_operation();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_producing_operation(const std::string& input_port_name) const {
const auto input_port = get_input_port(input_port_name);
if (input_port != nullptr) {
return input_port->get_producing_operation();
}
return nullptr;
}
bool PlaceOpONNX::is_equal(Place::Ptr another) const {
if (const auto place_op = std::dynamic_pointer_cast<PlaceOpONNX>(another)) {
const auto& another_node = place_op->get_editor_node();
if (m_editor->is_correct_and_unambiguous_node(m_node) ||
m_editor->is_correct_and_unambiguous_node(another_node)) {
return m_editor->get_node_index(m_node) == m_editor->get_node_index(another_node);
}
}
return false;
}
Place::Ptr PlaceOpONNX::get_target_tensor() const {
const auto output_port = get_output_port();
if (output_port != nullptr) {
return output_port->get_target_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_target_tensor(int output_port_index) const {
const auto output_port = get_output_port(output_port_index);
if (output_port != nullptr) {
return output_port->get_target_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_target_tensor(const std::string& output_name) const {
const auto output_port = get_output_port(output_name);
if (output_port != nullptr) {
return output_port->get_target_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_source_tensor() const {
const auto input_port = get_input_port();
if (input_port != nullptr) {
return input_port->get_source_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_source_tensor(int input_port_index) const {
const auto input_port = get_input_port(input_port_index);
if (input_port != nullptr) {
return input_port->get_source_tensor();
}
return nullptr;
}
Place::Ptr PlaceOpONNX::get_source_tensor(const std::string& input_name) const {
const auto input_port = get_input_port(input_name);
if (input_port != nullptr) {
return input_port->get_source_tensor();
}
return nullptr;
}
bool PlaceOpONNX::is_input() const {
return false;
}
bool PlaceOpONNX::is_output() const {
return false;
}
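
PlaceOpONNX::is_equal falls back to ONNXModelEditor::get_node_index, so two operation places compare equal whenever they resolve to the same node, regardless of how they were obtained. A sketch of that comparison, mirroring the Python test_is_equal case added below; the "out2" tensor and "split1" node names come from that test model, and the already-loaded InputModel is assumed.

#include <frontend_manager/frontend_manager.hpp>

using namespace ngraph::frontend;

bool same_operation(const InputModel::Ptr& model) {
    Place::Ptr via_tensor = model->get_place_by_tensor_name("out2")->get_producing_operation();
    Place::Ptr via_name = model->get_place_by_operation_name("split1");
    return via_tensor->is_equal(via_name);  // both handles resolve to the same node index
}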


@@ -7,6 +7,7 @@
#include <editor.hpp>
#include <frontend_manager/place.hpp>
#include <memory>
#include <sstream>

namespace ngraph {
namespace frontend {
@@ -15,17 +16,18 @@ public:
    PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
    PlaceInputEdgeONNX(onnx_editor::InputEdge&& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);

    // internal usage
    onnx_editor::InputEdge get_input_edge() const;

    // external usage
    bool is_input() const override;
    bool is_output() const override;
    bool is_equal(Place::Ptr another) const override;
    bool is_equal_data(Place::Ptr another) const override;
    Place::Ptr get_source_tensor() const override;
    std::vector<Place::Ptr> get_consuming_operations() const override;
    Place::Ptr get_producing_operation() const override;
    Place::Ptr get_producing_port() const override;

private:
    onnx_editor::InputEdge m_edge;
@@ -37,17 +39,18 @@ public:
    PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
    PlaceOutputEdgeONNX(onnx_editor::OutputEdge&& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);

    // internal usage
    onnx_editor::OutputEdge get_output_edge() const;

    // external usage
    bool is_input() const override;
    bool is_output() const override;
    bool is_equal(Place::Ptr another) const override;
    bool is_equal_data(Place::Ptr another) const override;
    Place::Ptr get_target_tensor() const override;
    std::vector<Place::Ptr> get_consuming_ports() const override;
    Place::Ptr get_producing_operation() const override;
    std::vector<Place::Ptr> get_consuming_operations() const override;

private:
    onnx_editor::OutputEdge m_edge;
@@ -59,21 +62,16 @@ public:
    PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
    PlaceTensorONNX(std::string&& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);

    // external usage
    std::vector<std::string> get_names() const override;
    Place::Ptr get_producing_port() const override;
    std::vector<Place::Ptr> get_consuming_ports() const override;
    Place::Ptr get_producing_operation() const override;
    bool is_input() const override;
    bool is_output() const override;
    bool is_equal(Place::Ptr another) const override;
    bool is_equal_data(Place::Ptr another) const override;
    std::vector<Place::Ptr> get_consuming_operations() const override;

private:
    std::string m_name;
@@ -86,6 +84,10 @@ public:
    PlaceOpONNX(onnx_editor::EditorNode&& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
    std::vector<std::string> get_names() const override;

    // internal usage
    onnx_editor::EditorNode get_editor_node() const;

    // external usage
    Place::Ptr get_output_port() const override;
    Place::Ptr get_output_port(int output_port_index) const override;
    Place::Ptr get_output_port(const std::string& output_port_name) const override;
@@ -94,6 +96,27 @@ public:
    Place::Ptr get_input_port(int input_port_index) const override;
    Place::Ptr get_input_port(const std::string& input_name) const override;

    std::vector<Place::Ptr> get_consuming_ports() const override;
    std::vector<Place::Ptr> get_consuming_operations() const override;
    std::vector<Place::Ptr> get_consuming_operations(int output_port_index) const override;
    std::vector<Place::Ptr> get_consuming_operations(const std::string& output_port_name) const override;

    Place::Ptr get_producing_operation() const override;
    Place::Ptr get_producing_operation(int input_port_index) const override;
    Place::Ptr get_producing_operation(const std::string& input_port_name) const override;

    Place::Ptr get_target_tensor() const override;
    Place::Ptr get_target_tensor(int output_port_index) const override;
    Place::Ptr get_target_tensor(const std::string& output_name) const override;

    Place::Ptr get_source_tensor() const override;
    Place::Ptr get_source_tensor(int input_port_index) const override;
    Place::Ptr get_source_tensor(const std::string& input_name) const override;

    bool is_equal(Place::Ptr another) const override;
    bool is_input() const override;
    bool is_output() const override;

private:
    onnx_editor::EditorNode m_node;
    std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;


@@ -693,6 +693,26 @@ NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_ind
    EXPECT_EQ(edge2.m_new_input_name, "custom_input_name_2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_index) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
const InputEdge edge = editor.find_input_edge(EditorNode{0}, EditorInput{0, "custom_input_name_1"});
EXPECT_EQ(edge.m_node_idx, 0);
EXPECT_EQ(edge.m_port_idx, 0);
EXPECT_EQ(edge.m_new_input_name, "custom_input_name_1");
const InputEdge edge2 = editor.find_input_edge(EditorNode{5}, EditorInput{0});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_port_idx, 0);
try {
editor.find_input_edge(EditorNode{99}, EditorInput{"conv1/7x7_s2_1"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name) {
    ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
@@ -767,6 +787,21 @@ NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_i
    EXPECT_EQ(edge2.m_port_idx, 1);
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_index) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
const OutputEdge edge = editor.find_output_edge(EditorNode{5}, EditorOutput{1});
EXPECT_EQ(edge.m_node_idx, 5);
EXPECT_EQ(edge.m_port_idx, 1);
try {
editor.find_output_edge(EditorNode{99}, EditorOutput{"conv1/7x7_s2_1"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network) {
    ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.onnx")};
@@ -1044,6 +1079,12 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
    is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"});
    EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{2});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{99});
EXPECT_EQ(is_correct_node, false);
    is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}});
    EXPECT_EQ(is_correct_node, false);
@@ -1054,6 +1095,32 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
    EXPECT_EQ(is_correct_node, false);
}
NGRAPH_TEST(onnx_editor, editor_api_get_node_index) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
EXPECT_EQ(editor.get_node_index(EditorNode{2}), 2);
EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"relu1"}}), 0);
EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"split2"}}), 5);
EXPECT_EQ(editor.get_node_index(EditorNode{"relu1_name"}), 0);
try {
editor.get_node_index(EditorNode{99});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
}
try {
editor.get_node_index(EditorNode{"add_ambiguous_name"});
} catch (const std::exception& e) {
std::string msg{e.what()};
EXPECT_TRUE(
msg.find(
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_input_edge_from_tensor_with_single_consumer) {
    ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_ab.onnx")};
@@ -1365,15 +1432,18 @@ NGRAPH_TEST(onnx_editor, get_input_ports) {
        editor.get_input_ports(EditorNode{"add_ambiguous_name"});
    } catch (const std::exception& e) {
        std::string msg{e.what()};
        EXPECT_TRUE(
            msg.find(
                "The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
            std::string::npos);
    }
    try {
        editor.get_input_ports(EditorNode{""});
    } catch (const std::exception& e) {
        std::string msg{e.what()};
        EXPECT_TRUE(
            msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") !=
            std::string::npos);
    }
}

NGRAPH_TEST(onnx_editor, get_output_ports) {
@@ -1393,14 +1463,17 @@ NGRAPH_TEST(onnx_editor, get_output_ports) {
        editor.get_output_ports(EditorNode{"add_ambiguous_name"});
    } catch (const std::exception& e) {
        std::string msg{e.what()};
        EXPECT_TRUE(
            msg.find(
                "The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
            std::string::npos);
    }
    try {
        editor.get_output_ports(EditorNode{""});
    } catch (const std::exception& e) {
        std::string msg{e.what()};
        EXPECT_TRUE(
            msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") !=
            std::string::npos);
    }
}


@@ -89,13 +89,23 @@ void regclass_pyngraph_Place(py::module m) {
    place.def(
        "get_consuming_operations",
        [](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) {
            if (outputName == py::none()) {
                if (outputPortIndex == py::none()) {
                    return self.get_consuming_operations();
                } else {
                    return self.get_consuming_operations(py::cast<int>(outputPortIndex));
                }
            } else {
                if (outputPortIndex == py::none()) {
                    return self.get_consuming_operations(py::cast<std::string>(outputName));
                } else {
                    return self.get_consuming_operations(py::cast<std::string>(outputName),
                                                         py::cast<int>(outputPortIndex));
                }
            }
        },
        py::arg("outputName") = py::none(),
        py::arg("outputPortIndex") = py::none(),
        R"(
        Returns references to all operation nodes that consume data from this place for specified output port.
@@ -103,6 +113,8 @@ void regclass_pyngraph_Place(py::module m) {
        Parameters
        ----------
        outputName : str
            Name of output port group. May not be set if node has one output port group.
        outputPortIndex : int
            If place is an operational node it specifies which output port should be considered
            May not be set if node has only one output port.
@@ -115,13 +127,22 @@ void regclass_pyngraph_Place(py::module m) {
    place.def(
        "get_target_tensor",
        [](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) {
            if (outputName == py::none()) {
                if (outputPortIndex == py::none()) {
                    return self.get_target_tensor();
                } else {
                    return self.get_target_tensor(py::cast<int>(outputPortIndex));
                }
            } else {
                if (outputPortIndex == py::none()) {
                    return self.get_target_tensor(py::cast<std::string>(outputName));
                } else {
                    return self.get_target_tensor(py::cast<std::string>(outputName), py::cast<int>(outputPortIndex));
                }
            }
        },
        py::arg("outputName") = py::none(),
        py::arg("outputPortIndex") = py::none(),
        R"(
        Returns a tensor place that gets data from this place; applicable for operations,
@@ -129,6 +150,8 @@ void regclass_pyngraph_Place(py::module m) {
        Parameters
        ----------
        outputName : str
            Name of output port group. May not be set if node has one output port group.
        outputPortIndex : int
            Output port index if the current place is an operation node and has multiple output ports.
            May not be set if place has only one output port.
@@ -141,19 +164,31 @@ void regclass_pyngraph_Place(py::module m) {
    place.def(
        "get_producing_operation",
        [](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) {
            if (inputName == py::none()) {
                if (inputPortIndex == py::none()) {
                    return self.get_producing_operation();
                } else {
                    return self.get_producing_operation(py::cast<int>(inputPortIndex));
                }
            } else {
                if (inputPortIndex == py::none()) {
                    return self.get_producing_operation(py::cast<std::string>(inputName));
                } else {
                    return self.get_producing_operation(py::cast<std::string>(inputName),
                                                        py::cast<int>(inputPortIndex));
                }
            }
        },
        py::arg("inputName") = py::none(),
        py::arg("inputPortIndex") = py::none(),
        R"(
        Get an operation node place that immediately produces data for this place.

        Parameters
        ----------
        inputName : str
            Name of port group. May not be set if node has one input port group.
        inputPortIndex : int
            If a given place is itself an operation node, this specifies a port index.
            May not be set if place has only one input port.
@@ -260,13 +295,22 @@ void regclass_pyngraph_Place(py::module m) {
    place.def(
        "get_source_tensor",
        [](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) {
            if (inputName == py::none()) {
                if (inputPortIndex == py::none()) {
                    return self.get_source_tensor();
                } else {
                    return self.get_source_tensor(py::cast<int>(inputPortIndex));
                }
            } else {
                if (inputPortIndex == py::none()) {
                    return self.get_source_tensor(py::cast<std::string>(inputName));
                } else {
                    return self.get_source_tensor(py::cast<std::string>(inputName), py::cast<int>(inputPortIndex));
                }
            }
        },
        py::arg("inputName") = py::none(),
        py::arg("inputPortIndex") = py::none(),
        R"(
        Returns a tensor place that supplies data for this place; applicable for operations,
@@ -274,6 +318,8 @@ void regclass_pyngraph_Place(py::module m) {
        Parameters
        ----------
        inputName : str
            Name of port group. May not be set if node has one input port group.
        inputPortIndex : int
            Input port index for operational node. May not be specified if place has only one input port.


@@ -107,12 +107,29 @@ public:
    std::vector<Place::Ptr> get_consuming_operations() const override {
        m_stat.m_get_consuming_operations++;
        m_stat.m_lastArgInt = -1;
        m_stat.m_lastArgString = "";
        return {std::make_shared<PlaceMockPy>()};
    }

    std::vector<Place::Ptr> get_consuming_operations(int outputPortIndex) const override {
        m_stat.m_get_consuming_operations++;
        m_stat.m_lastArgInt = outputPortIndex;
        m_stat.m_lastArgString = "";
        return {std::make_shared<PlaceMockPy>()};
    }

    std::vector<Place::Ptr> get_consuming_operations(const std::string& outputName) const override {
        m_stat.m_get_consuming_operations++;
        m_stat.m_lastArgInt = -1;
        m_stat.m_lastArgString = outputName;
        return {std::make_shared<PlaceMockPy>()};
    }

    std::vector<Place::Ptr> get_consuming_operations(const std::string& outputName,
                                                     int outputPortIndex) const override {
        m_stat.m_get_consuming_operations++;
        m_stat.m_lastArgInt = outputPortIndex;
        m_stat.m_lastArgString = outputName;
        return {std::make_shared<PlaceMockPy>()};
    }
@@ -128,6 +145,20 @@ public:
        return std::make_shared<PlaceMockPy>();
    }
Place::Ptr get_target_tensor(const std::string& outputName) const override {
m_stat.m_get_target_tensor++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = outputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const override {
m_stat.m_get_target_tensor++;
m_stat.m_lastArgInt = outputPortIndex;
m_stat.m_lastArgString = outputName;
return {std::make_shared<PlaceMockPy>()};
}
    Place::Ptr get_producing_operation() const override {
        m_stat.m_get_producing_operation++;
        m_stat.m_lastArgInt = -1;
@@ -140,6 +171,20 @@ public:
        return std::make_shared<PlaceMockPy>();
    }
Place::Ptr get_producing_operation(const std::string& inputName) const override {
m_stat.m_get_producing_operation++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_producing_operation(const std::string& inputName, int inputPortIndex) const override {
m_stat.m_get_producing_operation++;
m_stat.m_lastArgInt = inputPortIndex;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
    Place::Ptr get_producing_port() const override {
        m_stat.m_get_producing_port++;
        return std::make_shared<PlaceMockPy>();
@@ -236,6 +281,20 @@ public:
        return {std::make_shared<PlaceMockPy>()};
    }
Place::Ptr get_source_tensor(const std::string& inputName) const override {
m_stat.m_get_source_tensor++;
m_stat.m_lastArgInt = -1;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
Place::Ptr get_source_tensor(const std::string& inputName, int inputPortIndex) const override {
m_stat.m_get_source_tensor++;
m_stat.m_lastArgInt = inputPortIndex;
m_stat.m_lastArgString = inputName;
return {std::make_shared<PlaceMockPy>()};
}
    //---------------Stat--------------------
    PlaceStat get_stat() const {
        return m_stat;


@@ -9,6 +9,7 @@ from ngraph import PartialShape
from ngraph.frontend import FrontEndManager
# ------Test input model 1------
# in1 in2 in3 # in1 in2 in3
# | | | # | | |
# \ / | # \ / |
@ -24,9 +25,32 @@ from ngraph.frontend import FrontEndManager
# / \ | # / \ |
# out1 out2 out4 # out1 out2 out4
# #
#
# ------Test input model 2------
# in1 in2
# | |
# \ /
# +--------+
# | Add |
# +--------+
# <add_out>
# |
# +--------+
# | Split |
# |(split2)|
# +--------+
# / \
# <sp_out1> <sp_out2>
# +-------+ +-------+
# | Abs | | Sin |
# | (abs1)| | |
# +------ + +-------+
# | |
# out1 out2
#
def create_test_onnx_models():
    models = {}

    # Input model 1
    add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"])
    split = onnx.helper.make_node("Split", inputs=["add_out"],
                                  outputs=["out1", "out2"], name="split1", axis=0)
@@ -48,6 +72,24 @@ def create_test_onnx_models():
    models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer",
                                            opset_imports=[onnx.helper.make_opsetid("", 13)])
# Input model 2
split_2 = onnx.helper.make_node("Split", inputs=["add_out"],
outputs=["sp_out1", "sp_out2"], name="split2", axis=0)
abs = onnx.helper.make_node("Abs", inputs=["sp_out1"], outputs=["out1"], name="abs1")
sin = onnx.helper.make_node("Sin", inputs=["sp_out2"], outputs=["out2"])
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
]
graph = make_graph([add, split_2, abs, sin], "test_graph_2", input_tensors, output_tensors)
models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
    # Expected for extract_subgraph
    input_tensors = [
        make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
@@ -312,8 +354,9 @@ def test_extract_subgraph_4():
    model = fe.load("input_model.onnx")
    assert model

    out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
    place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
    place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
    place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
    place4 = model.get_place_by_tensor_name(tensorName="out1")
    place5 = model.get_place_by_tensor_name(tensorName="out2")
@@ -377,8 +420,9 @@ def test_override_all_inputs():
    place1 = model.get_place_by_operation_name_and_input_port(
        operationName="split1", inputPortIndex=0)
    out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
    place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
    place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
    place4 = model.get_place_by_tensor_name(tensorName="in3")
    model.override_all_inputs(inputs=[place1, place2, place3, place4])
    result_func = fe.convert(model)
@@ -432,11 +476,15 @@ def test_is_input_output():
    assert not place3.is_input()
    assert not place3.is_output()
    place4 = model.get_place_by_operation_name_and_input_port(
        operationName="split1", inputPortIndex=0)
    assert not place4.is_input()
    assert not place4.is_output()
place5 = model.get_place_by_operation_name(operationName="split1")
assert not place5.is_input()
assert not place5.is_output()
def test_set_partial_shape():
    skip_if_onnx_frontend_is_disabled()
@@ -520,12 +568,14 @@ def test_is_equal():
    place2 = model.get_place_by_tensor_name(tensorName="out2")
    assert place2.is_equal(place2)
    out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
    place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
    place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
    assert place3.is_equal(place4)
    out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
    place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
    place6 = out1_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
    assert place5.is_equal(place6)
    place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port()
@@ -538,6 +588,10 @@ def test_is_equal():
    assert not place6.is_equal(place7)
    assert not place8.is_equal(place2)
place9 = model.get_place_by_operation_name(operationName="split1")
assert place2.get_producing_operation().is_equal(place9)
assert not place9.is_equal(place2)
def test_is_equal_data():
    skip_if_onnx_frontend_is_disabled()
@@ -560,11 +614,12 @@ def test_is_equal_data():
    place4 = place2.get_producing_port()
    assert place2.is_equal_data(place4)
    out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
    place5 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
    assert place2.is_equal_data(place5)
    assert place4.is_equal_data(place5)
    place6 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
    assert place6.is_equal_data(place5)
    place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
@@ -648,3 +703,198 @@ def test_get_input_port():
    assert not split_op.get_input_port(inputPortIndex=1)
    assert not split_op.get_input_port(inputName="not_existed")
def test_get_consuming_ports():
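    # Checks that every input port consuming the "add_out" tensor is reported,
    # both when asked from the tensor place and from its producing Add operation.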
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
add_tensor_consuming_ports = place1.get_consuming_ports()
assert len(add_tensor_consuming_ports) == 3
place2 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
assert add_tensor_consuming_ports[0].is_equal(place2)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert add_tensor_consuming_ports[1].is_equal(place3)
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
assert add_tensor_consuming_ports[2].is_equal(place4)
add_op_consuming_ports = place1.get_producing_operation().get_consuming_ports()
assert len(add_op_consuming_ports) == len(add_tensor_consuming_ports)
for i in range(len(add_op_consuming_ports)):
assert add_op_consuming_ports[i].is_equal(add_tensor_consuming_ports[i])
def test_get_consuming_ports_2():
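    # The consuming ports of the "split2" operation should be the input ports of
    # the Abs and Sin nodes; a single output port reports only its own consumer.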
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op_consuming_ports = split_op.get_consuming_ports()
assert len(split_op_consuming_ports) == 2
abs_input_port = model.get_place_by_operation_name(operationName="abs1").get_input_port(inputPortIndex=0)
assert split_op_consuming_ports[0].is_equal(abs_input_port)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
sin_input_port = out2_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
assert split_op_consuming_ports[1].is_equal(sin_input_port)
split_out_port_0 = split_op.get_output_port(outputPortIndex=0)
split_out_port_0_consuming_ports = split_out_port_0.get_consuming_ports()
assert len(split_out_port_0_consuming_ports) == 1
assert split_out_port_0_consuming_ports[0].is_equal(abs_input_port)
def test_get_producing_operation():
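    # get_producing_operation resolves both a tensor place ("sp_out2") and an
    # output port of "split2" back to the Split operation itself.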
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_tensor_out_2 = model.get_place_by_tensor_name(tensorName="sp_out2")
split_op = model.get_place_by_operation_name(operationName="split2")
assert split_tensor_out_2.get_producing_operation().is_equal(split_op)
split_op = model.get_place_by_operation_name(operationName="split2")
split_out_port_2 = split_op.get_output_port(outputPortIndex=1)
assert split_out_port_2.get_producing_operation().is_equal(split_op)
def test_get_producing_operation_2():
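    # get_producing_operation walks one step upstream from input ports and operations;
    # for the Add node, whose inputs are model inputs, it returns a null place, and the
    # producer can also be selected by input name or input port index.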
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
abs_op = model.get_place_by_operation_name(operationName="abs1")
abs_port_0 = abs_op.get_input_port()
split_op = model.get_place_by_operation_name(operationName="split2")
assert abs_port_0.get_producing_operation().is_equal(split_op)
assert abs_op.get_producing_operation().is_equal(split_op)
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_op = add_out_tensor.get_producing_operation()
assert not add_op.get_producing_operation()
split_op_producing_op = split_op.get_producing_operation(inputName="add_out")
assert split_op_producing_op.is_equal(add_op)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
sin_op = out2_tensor.get_producing_operation()
assert sin_op.get_producing_operation(inputPortIndex=0).is_equal(split_op)
def test_get_consuming_operations():
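    # get_consuming_operations is exercised on an operation, an input port, a producing
    # port, and tensors; model output tensors have no consumers, and a specific output
    # can be selected by port index or by name ("sp_out2").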
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op_consuming_ops = split_op.get_consuming_operations()
abs_op = model.get_place_by_operation_name(operationName="abs1")
sin_op = model.get_place_by_tensor_name(tensorName="out2").get_producing_operation()
assert len(split_op_consuming_ops) == 2
assert split_op_consuming_ops[0].is_equal(abs_op)
assert split_op_consuming_ops[1].is_equal(sin_op)
split_op_port = split_op.get_input_port(inputPortIndex=0)
split_op_port_consuming_ops = split_op_port.get_consuming_operations()
assert len(split_op_port_consuming_ops) == 1
assert split_op_port_consuming_ops[0].is_equal(split_op)
add_out_port = model.get_place_by_tensor_name(tensorName="add_out").get_producing_port()
add_out_port_consuming_ops = add_out_port.get_consuming_operations()
assert len(add_out_port_consuming_ops) == 1
assert add_out_port_consuming_ops[0].is_equal(split_op)
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
sp_out2_tensor_consuming_ops = sp_out2_tensor.get_consuming_operations()
assert len(sp_out2_tensor_consuming_ops) == 1
assert sp_out2_tensor_consuming_ops[0].is_equal(sin_op)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
out2_tensor_consuming_ops = out2_tensor.get_consuming_operations()
assert len(out2_tensor_consuming_ops) == 0
out2_port_consuming_ops = out2_tensor.get_producing_port().get_consuming_operations()
assert len(out2_port_consuming_ops) == 0
split_out_1_consuming_ops = split_op.get_consuming_operations(outputPortIndex=1)
assert len(split_out_1_consuming_ops) == 1
split_out_sp_out_2_consuming_ops = split_op.get_consuming_operations(outputName="sp_out2")
assert len(split_out_sp_out_2_consuming_ops) == 1
assert split_out_1_consuming_ops[0].is_equal(split_out_sp_out_2_consuming_ops[0])
assert split_out_1_consuming_ops[0].is_equal(sin_op)
def test_get_target_tensor():
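    # get_target_tensor with no argument is ambiguous for the two-output Split and
    # returns a null place; selecting by outputPortIndex or outputName yields the
    # "sp_out2" tensor, while the single-output Abs resolves to "out1".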
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
assert not split_op.get_target_tensor()
split_op_tensor_1 = split_op.get_target_tensor(outputPortIndex=1)
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
assert split_op_tensor_1.is_equal(sp_out2_tensor)
split_tensor_sp_out2 = split_op.get_target_tensor(outputName="sp_out2")
assert split_tensor_sp_out2.is_equal(split_op_tensor_1)
abs_op = model.get_place_by_operation_name(operationName="abs1")
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
assert abs_op.get_target_tensor().is_equal(out1_tensor)
def test_get_source_tensor():
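    # get_source_tensor with no argument is ambiguous for the two-input Add and
    # returns a null place; selecting by inputPortIndex or inputName yields "in2",
    # while the single-input Split resolves to "add_out".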
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_op = add_out_tensor.get_producing_operation()
assert not add_op.get_source_tensor()
add_op_in_tensor_1 = add_op.get_source_tensor(inputPortIndex=1)
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
assert add_op_in_tensor_1.is_equal(in2_tensor)
add_op_in_tensor_in2 = add_op.get_source_tensor(inputName="in2")
assert add_op_in_tensor_in2.is_equal(in2_tensor)
split_op = model.get_place_by_operation_name(operationName="split2")
assert split_op.get_source_tensor().is_equal(add_out_tensor)
def test_get_producing_port():
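    # The producing port of Split's input port should be the output port of the
    # Add operation that produces the "add_out" tensor.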
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op_in_port = split_op.get_input_port()
split_op_in_port_prod_port = split_op_in_port.get_producing_port()
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_op = add_out_tensor.get_producing_operation()
add_op_out_port = add_op.get_output_port()
assert split_op_in_port_prod_port.is_equal(add_op_out_port)
View File
@@ -440,6 +440,16 @@ def test_place_get_consuming_operations():
    stat = get_place_stat(place)
    assert stat.get_consuming_operations == 2
    assert stat.lastArgInt == -1
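    # The name-based and name-plus-index overloads should forward their arguments
    # to the mock place, which records them in lastArgString / lastArgInt.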
assert place.get_consuming_operations(outputName="2") is not None
stat = get_place_stat(place)
assert stat.get_consuming_operations == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_consuming_operations(outputName="3", outputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_consuming_operations == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"
@mock_needed
@@ -453,6 +463,16 @@ def test_place_get_target_tensor():
    stat = get_place_stat(place)
    assert stat.get_target_tensor == 2
    assert stat.lastArgInt == -1
assert place.get_target_tensor(outputName="2") is not None
stat = get_place_stat(place)
assert stat.get_target_tensor == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_target_tensor(outputName="3", outputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_target_tensor == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"
@mock_needed
@@ -466,6 +486,16 @@ def test_place_get_producing_operation():
    stat = get_place_stat(place)
    assert stat.get_producing_operation == 2
    assert stat.lastArgInt == -1
assert place.get_producing_operation(inputName="2") is not None
stat = get_place_stat(place)
assert stat.get_producing_operation == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_producing_operation(inputName="3", inputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_producing_operation == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"
@mock_needed
@@ -551,3 +581,13 @@ def test_place_get_source_tensor():
    stat = get_place_stat(place)
    assert stat.get_source_tensor == 2
    assert stat.lastArgInt == 22
assert place.get_source_tensor(inputName="2") is not None
stat = get_place_stat(place)
assert stat.get_source_tensor == 3
assert stat.lastArgInt == -1
assert stat.lastArgString == "2"
assert place.get_source_tensor(inputName="3", inputPortIndex=33) is not None
stat = get_place_stat(place)
assert stat.get_source_tensor == 4
assert stat.lastArgInt == 33
assert stat.lastArgString == "3"