Implemented missing methods of ONNX Place API (#7518)
* Implemented get_consuming_ports, get_producing_operation and is_equal for PlaceOpONNX * fixed unambiguous_node_check * removed PlaceTensorONNX::get_input_port * added PlaceOpONNX::is_input, PlaceOpONNX::is_output * fixed python styles * added get_consuming_operations implementation * added missing get_consuming_operations for PlaceOpONNX * added missing get_target_tensor for PlaceOpONNX * changed place spec * add support of get_source_tensor * add support of get_producing_operation for PlaceOpONNX * add support of get_producing_port for PlaceInputEdgeONNX * fixed python styles * missing ref in std::transform
This commit is contained in:
parent
d7dfce2091
commit
fb11560b82
@ -103,7 +103,7 @@ public:
|
||||
///
|
||||
/// \return A vector with all operation node references that consumes data from this
|
||||
/// place
|
||||
virtual std::vector<Ptr> get_consuming_operations(const std::string& outputPortName) const;
|
||||
virtual std::vector<Ptr> get_consuming_operations(const std::string& outputName) const;
|
||||
|
||||
/// \brief Returns references to all operation nodes that consume data from this place
|
||||
/// for specified output port
|
||||
@ -127,16 +127,14 @@ public:
|
||||
/// \return A tensor place which hold the resulting value for this place
|
||||
virtual Ptr get_target_tensor() const;
|
||||
|
||||
/// \brief Returns a tensor place that gets data from this place; applicable for
|
||||
/// operations, output ports and output edges which have only one output port
|
||||
/// \brief Returns a tensor place that gets data from this place; applicable for operations
|
||||
///
|
||||
/// \param outputName Name of output port group
|
||||
///
|
||||
/// \return A tensor place which hold the resulting value for this place
|
||||
virtual Ptr get_target_tensor(const std::string& outputName) const;
|
||||
|
||||
/// \brief Returns a tensor place that gets data from this place; applicable for
|
||||
/// operations, output ports and output edges which have only one output port
|
||||
/// \brief Returns a tensor place that gets data from this place; applicable for operations
|
||||
///
|
||||
/// \param outputName Name of output port group, each group can have multiple ports
|
||||
///
|
||||
@ -146,8 +144,7 @@ public:
|
||||
/// \return A tensor place which hold the resulting value for this place
|
||||
virtual Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const;
|
||||
|
||||
/// \brief Returns a tensor place that gets data from this place; applicable for
|
||||
/// operations, output ports and output edges
|
||||
/// \brief Returns a tensor place that gets data from this place; applicable for operations
|
||||
///
|
||||
/// \param output_port_index Output port index if the current place is an operation node
|
||||
/// and has multiple output ports
|
||||
@ -161,24 +158,21 @@ public:
|
||||
/// \return A tensor place which supplies data for this place
|
||||
virtual Ptr get_source_tensor() const;
|
||||
|
||||
/// \brief Returns a tensor place that supplies data for this place; applicable for
|
||||
/// operations, input ports and input edges
|
||||
/// \brief Returns a tensor place that supplies data for this place; applicable for operations
|
||||
///
|
||||
/// \param input_port_index Input port index for operational nodes.
|
||||
///
|
||||
/// \return A tensor place which supplies data for this place
|
||||
virtual Ptr get_source_tensor(int input_port_index) const;
|
||||
|
||||
/// \brief Returns a tensor place that supplies data for this place; applicable for
|
||||
/// operations, input ports and input edges
|
||||
/// \brief Returns a tensor place that supplies data for this place; applicable for operations
|
||||
///
|
||||
/// \param inputName Name of input port group
|
||||
///
|
||||
/// \return A tensor place which supplies data for this place
|
||||
virtual Ptr get_source_tensor(const std::string& inputName) const;
|
||||
|
||||
/// \brief Returns a tensor place that supplies data for this place; applicable for
|
||||
/// operations, input ports and input edges
|
||||
/// \brief Returns a tensor place that supplies data for this place; applicable for operations
|
||||
///
|
||||
/// \param inputName If a given place is itself an operation node, this specifies name
|
||||
/// of output port group, each group can have multiple ports
|
||||
|
@ -95,37 +95,41 @@ std::vector<int> onnx_editor::EdgeMapper::get_node_input_indexes(int node_index,
|
||||
}
|
||||
|
||||
InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const EditorInput& in) const {
|
||||
// identification can be both based on node name and output name
|
||||
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
int node_index = -1;
|
||||
if (node_indexes.size() == 1) {
|
||||
node_index = node_indexes[0];
|
||||
} else if (node_indexes.empty()) {
|
||||
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" was not found");
|
||||
} else if (!in.m_input_name.empty()) // input indexes are not deterministic if a node name is ambiguous
|
||||
{
|
||||
// many nodes with the same name
|
||||
// check if some of found index matches input name
|
||||
int matched_inputs_number = 0;
|
||||
for (const auto& index : node_indexes) {
|
||||
if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) {
|
||||
node_index = index;
|
||||
++matched_inputs_number;
|
||||
}
|
||||
}
|
||||
if (matched_inputs_number == 0) {
|
||||
throw ngraph_error("Input edge described by: " + node.m_node_name + " and input name: " + in.m_input_name +
|
||||
int node_index = node.m_node_index;
|
||||
if (node_index == -1) { // the node index is not provided
|
||||
// identification can be both based on node name and output name (if the node index is not provided)
|
||||
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
if (node_indexes.size() == 1) {
|
||||
node_index = node_indexes[0];
|
||||
} else if (node_indexes.empty()) {
|
||||
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" was not found");
|
||||
} else if (!in.m_input_name.empty()) // input indexes are not deterministic if a node name is ambiguous
|
||||
{
|
||||
// many nodes with the same name
|
||||
// check if some of found index matches input name
|
||||
int matched_inputs_number = 0;
|
||||
for (const auto& index : node_indexes) {
|
||||
if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) {
|
||||
node_index = index;
|
||||
++matched_inputs_number;
|
||||
}
|
||||
}
|
||||
if (matched_inputs_number == 0) {
|
||||
throw ngraph_error("Input edge described by: " + node.m_node_name +
|
||||
" and input name: " + in.m_input_name + " was not found");
|
||||
}
|
||||
if (matched_inputs_number > 1) {
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name +
|
||||
" are ambiguous to determine input edge");
|
||||
}
|
||||
} else {
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " +
|
||||
std::to_string(in.m_input_index) + " are ambiguous to determine input edge");
|
||||
}
|
||||
if (matched_inputs_number > 1) {
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name +
|
||||
" are ambiguous to determine input edge");
|
||||
}
|
||||
} else {
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " +
|
||||
std::to_string(in.m_input_index) + " are ambiguous to determine input edge");
|
||||
} else { // the node index is provided
|
||||
check_node_index(node_index);
|
||||
}
|
||||
if (in.m_input_index != -1) // input index is set
|
||||
{
|
||||
@ -146,33 +150,38 @@ InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const
|
||||
}
|
||||
|
||||
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node, const EditorOutput& out) const {
|
||||
// identification can be both based on node name and output name
|
||||
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
int node_index = -1;
|
||||
if (node_indexes.size() == 1) {
|
||||
node_index = node_indexes[0];
|
||||
} else if (node_indexes.empty()) {
|
||||
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" was not found");
|
||||
} else if (!out.m_output_name.empty()) // output indexes are not deterministic if a node name is ambiguous
|
||||
{
|
||||
// many nodes with the same name
|
||||
// check if some of found index matches output name
|
||||
int matched_outputs_number = 0;
|
||||
for (const auto& index : node_indexes) {
|
||||
if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) > 0) {
|
||||
node_index = index;
|
||||
++matched_outputs_number;
|
||||
int node_index = node_index = node.m_node_index;
|
||||
if (node_index == -1) { // the node index is not provided
|
||||
// identification can be both based on node name and output name (if the node index is not provided)
|
||||
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
if (node_indexes.size() == 1) {
|
||||
node_index = node_indexes[0];
|
||||
} else if (node_indexes.empty()) {
|
||||
throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" was not found");
|
||||
} else if (!out.m_output_name.empty()) // output indexes are not deterministic if a node name is ambiguous
|
||||
{
|
||||
// many nodes with the same name
|
||||
// check if some of found index matches output name
|
||||
int matched_outputs_number = 0;
|
||||
for (const auto& index : node_indexes) {
|
||||
if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) >
|
||||
0) {
|
||||
node_index = index;
|
||||
++matched_outputs_number;
|
||||
}
|
||||
}
|
||||
if (matched_outputs_number == 0) {
|
||||
throw ngraph_error("Output edge described by: " + node.m_node_name +
|
||||
" and output name: " + out.m_output_name + " was not found");
|
||||
}
|
||||
} else {
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " +
|
||||
std::to_string(out.m_output_index) + " are ambiguous to determine output edge");
|
||||
}
|
||||
if (matched_outputs_number == 0) {
|
||||
throw ngraph_error("Output edge described by: " + node.m_node_name +
|
||||
" and output name: " + out.m_output_name + " was not found");
|
||||
}
|
||||
} else {
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " +
|
||||
std::to_string(out.m_output_index) + " are ambiguous to determine output edge");
|
||||
} else { // the node index is provided
|
||||
check_node_index(node_index);
|
||||
}
|
||||
if (out.m_output_index != -1) // output index is set
|
||||
{
|
||||
@ -197,16 +206,45 @@ std::vector<InputEdge> onnx_editor::EdgeMapper::find_output_consumers(const std:
|
||||
const auto node_idx = it->second;
|
||||
const auto port_indexes = get_node_input_indexes(node_idx, output_name);
|
||||
for (const auto& idx : port_indexes) {
|
||||
input_edges.push_back(InputEdge{node_idx, idx});
|
||||
const auto consumer_edge = InputEdge{node_idx, idx};
|
||||
if (std::find_if(std::begin(input_edges), std::end(input_edges), [&consumer_edge](const InputEdge& edge) {
|
||||
return edge.m_node_idx == consumer_edge.m_node_idx && edge.m_port_idx == consumer_edge.m_port_idx;
|
||||
}) == std::end(input_edges)) {
|
||||
// only unique
|
||||
input_edges.push_back(consumer_edge);
|
||||
}
|
||||
}
|
||||
}
|
||||
return input_edges;
|
||||
}
|
||||
|
||||
bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const {
|
||||
if (node.m_node_index >= 0 && node.m_node_index < static_cast<int>(m_node_inputs.size())) {
|
||||
return true;
|
||||
}
|
||||
return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
|
||||
}
|
||||
|
||||
namespace {
|
||||
void check_node(bool condition, const EditorNode& node) {
|
||||
NGRAPH_CHECK(condition,
|
||||
"The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
", node_index: " + (node.m_node_index == -1 ? "not_given" : std::to_string(node.m_node_index)) +
|
||||
" is ambiguous");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int onnx_editor::EdgeMapper::get_node_index(const EditorNode& node) const {
|
||||
if (node.m_node_index != -1) { // the node index provided
|
||||
check_node_index(node.m_node_index);
|
||||
return node.m_node_index;
|
||||
}
|
||||
const auto indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
check_node(indexes.size() == 1, node);
|
||||
return indexes[0];
|
||||
}
|
||||
|
||||
bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) const {
|
||||
if (m_node_output_name_to_index.find(name) != std::end(m_node_output_name_to_index)) {
|
||||
return true;
|
||||
@ -218,20 +256,25 @@ bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) co
|
||||
}
|
||||
|
||||
std::vector<std::string> onnx_editor::EdgeMapper::get_input_ports(const EditorNode& node) const {
|
||||
NGRAPH_CHECK(is_correct_and_unambiguous_node(node),
|
||||
"The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" is ambiguous");
|
||||
const auto node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
|
||||
check_node(is_correct_and_unambiguous_node(node), node);
|
||||
auto node_index = node.m_node_index;
|
||||
if (node_index == -1) { // the node index is provided
|
||||
node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
|
||||
} else {
|
||||
check_node_index(node_index);
|
||||
}
|
||||
return m_node_inputs[node_index];
|
||||
}
|
||||
|
||||
std::vector<std::string> onnx_editor::EdgeMapper::get_output_ports(const EditorNode& node) const {
|
||||
NGRAPH_CHECK(is_correct_and_unambiguous_node(node),
|
||||
"The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" is ambiguous");
|
||||
const auto node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
|
||||
check_node(is_correct_and_unambiguous_node(node), node);
|
||||
auto node_index = node.m_node_index;
|
||||
if (node_index == -1) // the node index is provided
|
||||
{
|
||||
node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0];
|
||||
} else {
|
||||
check_node_index(node_index);
|
||||
}
|
||||
return m_node_outputs[node_index];
|
||||
}
|
||||
|
||||
@ -250,3 +293,8 @@ std::string onnx_editor::EdgeMapper::get_target_tensor_name(const OutputEdge& ed
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
void onnx_editor::EdgeMapper::check_node_index(int node_index) const {
|
||||
NGRAPH_CHECK(node_index >= 0 && node_index < static_cast<int>(m_node_inputs.size()),
|
||||
"Provided node index: " + std::to_string(node_index) + " is out of scope");
|
||||
}
|
||||
|
@ -92,6 +92,16 @@ public:
|
||||
///
|
||||
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
|
||||
|
||||
/// \brief Returns index (position) of provided node in the graph
|
||||
/// in topological order.
|
||||
///
|
||||
/// \param node An EditorNode helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
/// \note The exception will be thrown if the provided node is ambiguous.
|
||||
///
|
||||
int get_node_index(const EditorNode& node) const;
|
||||
|
||||
/// \brief Returns true if a provided tensor name is correct (exists in a graph).
|
||||
///
|
||||
/// \param name The name of tensor in a graph.
|
||||
@ -130,6 +140,7 @@ private:
|
||||
// note: a single node can have more than one inputs with the same name
|
||||
std::vector<int> get_node_input_indexes(int node_index, const std::string& input_name) const;
|
||||
int get_node_output_idx(int node_index, const std::string& output_name) const;
|
||||
void check_node_index(int node_index) const;
|
||||
|
||||
std::vector<std::vector<std::string>> m_node_inputs;
|
||||
std::vector<std::vector<std::string>> m_node_outputs;
|
||||
|
@ -444,6 +444,11 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN
|
||||
return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
|
||||
}
|
||||
|
||||
int onnx_editor::ONNXModelEditor::get_node_index(const EditorNode& node) const {
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.get_node_index(node);
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_correct_tensor_name(const std::string& name) const {
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.is_correct_tensor_name(name);
|
||||
|
@ -198,6 +198,16 @@ public:
|
||||
///
|
||||
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
|
||||
|
||||
/// \brief Returns index (position) of provided node in the graph
|
||||
/// in topological order.
|
||||
///
|
||||
/// \param node An EditorNode helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
/// \note The exception will be thrown if the provided node is ambiguous.
|
||||
///
|
||||
int get_node_index(const EditorNode& node) const;
|
||||
|
||||
/// \brief Returns true if a provided tensor name is correct (exists in a graph).
|
||||
///
|
||||
/// \param name The name of tensor in a graph.
|
||||
|
@ -114,11 +114,14 @@ struct EditorOutput {
|
||||
/// You can indicate test_node by name as EditorNode("test_node")
|
||||
/// or by assigned output as EditorNode(EditorOutput("out1"))
|
||||
/// or EditorNode(EditorOutput("out2"))
|
||||
/// or you can determine the node by postition of a node in an ONNX graph (in topological order).
|
||||
struct EditorNode {
|
||||
EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {}
|
||||
EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {}
|
||||
EditorNode(const int node_index) : m_node_index{node_index} {}
|
||||
const std::string m_node_name = "";
|
||||
const std::string m_output_name = "";
|
||||
const int m_node_index = -1;
|
||||
};
|
||||
} // namespace onnx_editor
|
||||
} // namespace ngraph
|
||||
|
@ -47,6 +47,18 @@ Place::Ptr PlaceInputEdgeONNX::get_source_tensor() const {
|
||||
return std::make_shared<PlaceTensorONNX>(m_editor->get_source_tensor_name(m_edge), m_editor);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceInputEdgeONNX::get_consuming_operations() const {
|
||||
return {std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor)};
|
||||
}
|
||||
|
||||
Place::Ptr PlaceInputEdgeONNX::get_producing_operation() const {
|
||||
return get_source_tensor()->get_producing_operation();
|
||||
}
|
||||
|
||||
Place::Ptr PlaceInputEdgeONNX::get_producing_port() const {
|
||||
return get_source_tensor()->get_producing_port();
|
||||
}
|
||||
|
||||
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_edge{edge},
|
||||
@ -85,6 +97,18 @@ Place::Ptr PlaceOutputEdgeONNX::get_target_tensor() const {
|
||||
return std::make_shared<PlaceTensorONNX>(m_editor->get_target_tensor_name(m_edge), m_editor);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceOutputEdgeONNX::get_consuming_ports() const {
|
||||
return get_target_tensor()->get_consuming_ports();
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOutputEdgeONNX::get_producing_operation() const {
|
||||
return std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceOutputEdgeONNX::get_consuming_operations() const {
|
||||
return get_target_tensor()->get_consuming_operations();
|
||||
}
|
||||
|
||||
PlaceTensorONNX::PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_name{name},
|
||||
m_editor{std::move(editor)} {}
|
||||
@ -112,10 +136,8 @@ std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_ports() const {
|
||||
return ret;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceTensorONNX::get_input_port(int input_port_index) const {
|
||||
return std::make_shared<PlaceInputEdgeONNX>(
|
||||
m_editor->find_input_edge(onnx_editor::EditorOutput(m_name), onnx_editor::EditorInput(input_port_index)),
|
||||
m_editor);
|
||||
Place::Ptr PlaceTensorONNX::get_producing_operation() const {
|
||||
return get_producing_port()->get_producing_operation();
|
||||
}
|
||||
|
||||
bool PlaceTensorONNX::is_input() const {
|
||||
@ -146,6 +168,19 @@ bool PlaceTensorONNX::is_equal_data(Place::Ptr another) const {
|
||||
eq_to_consuming_port(another);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_operations() const {
|
||||
std::vector<Place::Ptr> consuming_ports = get_consuming_ports();
|
||||
std::vector<Place::Ptr> consuming_ops;
|
||||
std::transform(std::begin(consuming_ports),
|
||||
std::end(consuming_ports),
|
||||
std::back_inserter(consuming_ops),
|
||||
[](const Place::Ptr& place) {
|
||||
return place->get_consuming_operations().at(0);
|
||||
});
|
||||
|
||||
return consuming_ops;
|
||||
}
|
||||
|
||||
PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_node{node},
|
||||
m_editor{std::move(editor)} {}
|
||||
@ -158,6 +193,10 @@ std::vector<std::string> PlaceOpONNX::get_names() const {
|
||||
return {m_node.m_node_name};
|
||||
}
|
||||
|
||||
onnx_editor::EditorNode PlaceOpONNX::get_editor_node() const {
|
||||
return m_node;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_output_port() const {
|
||||
if (m_editor->get_output_ports(m_node).size() == 1) {
|
||||
return get_output_port(0);
|
||||
@ -209,3 +248,133 @@ Place::Ptr PlaceOpONNX::get_input_port(const std::string& input_name) const {
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_ports() const {
|
||||
std::vector<Place::Ptr> consuming_ports;
|
||||
const auto out_ports_size = m_editor->get_output_ports(m_node).size();
|
||||
for (int out_idx = 0; out_idx < out_ports_size; ++out_idx) {
|
||||
auto consuming_ops_out = get_output_port(out_idx)->get_consuming_ports();
|
||||
consuming_ports.insert(consuming_ports.end(), consuming_ops_out.begin(), consuming_ops_out.end());
|
||||
}
|
||||
return consuming_ports;
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::vector<Place::Ptr> get_consuming_ops(std::vector<Place::Ptr> input_ports) {
|
||||
std::vector<Place::Ptr> consuming_ops;
|
||||
std::transform(std::begin(input_ports),
|
||||
std::end(input_ports),
|
||||
std::back_inserter(consuming_ops),
|
||||
[](const Place::Ptr place) {
|
||||
return place->get_consuming_operations().at(0);
|
||||
});
|
||||
|
||||
return consuming_ops;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations() const {
|
||||
std::vector<Place::Ptr> consuming_ports = get_consuming_ports();
|
||||
return get_consuming_ops(consuming_ports);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations(int output_port_index) const {
|
||||
std::vector<Place::Ptr> consuming_ports = get_output_port(output_port_index)->get_consuming_ports();
|
||||
return get_consuming_ops(consuming_ports);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceOpONNX::get_consuming_operations(const std::string& output_port_name) const {
|
||||
std::vector<Place::Ptr> consuming_ports = get_output_port(output_port_name)->get_consuming_ports();
|
||||
return get_consuming_ops(consuming_ports);
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_producing_operation() const {
|
||||
const auto input_port = get_input_port();
|
||||
if (input_port != nullptr) {
|
||||
return input_port->get_producing_operation();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_producing_operation(int input_port_index) const {
|
||||
const auto input_port = get_input_port(input_port_index);
|
||||
if (input_port != nullptr) {
|
||||
return input_port->get_producing_operation();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_producing_operation(const std::string& input_port_name) const {
|
||||
const auto input_port = get_input_port(input_port_name);
|
||||
if (input_port != nullptr) {
|
||||
return input_port->get_producing_operation();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool PlaceOpONNX::is_equal(Place::Ptr another) const {
|
||||
if (const auto place_op = std::dynamic_pointer_cast<PlaceOpONNX>(another)) {
|
||||
const auto& another_node = place_op->get_editor_node();
|
||||
if (m_editor->is_correct_and_unambiguous_node(m_node) ||
|
||||
m_editor->is_correct_and_unambiguous_node(another_node)) {
|
||||
return m_editor->get_node_index(m_node) == m_editor->get_node_index(another_node);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_target_tensor() const {
|
||||
const auto output_port = get_output_port();
|
||||
if (output_port != nullptr) {
|
||||
return output_port->get_target_tensor();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_target_tensor(int output_port_index) const {
|
||||
const auto output_port = get_output_port(output_port_index);
|
||||
if (output_port != nullptr) {
|
||||
return output_port->get_target_tensor();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_target_tensor(const std::string& output_name) const {
|
||||
const auto output_port = get_output_port(output_name);
|
||||
if (output_port != nullptr) {
|
||||
return output_port->get_target_tensor();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_source_tensor() const {
|
||||
const auto input_port = get_input_port();
|
||||
if (input_port != nullptr) {
|
||||
return input_port->get_source_tensor();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_source_tensor(int input_port_index) const {
|
||||
const auto input_port = get_input_port(input_port_index);
|
||||
if (input_port != nullptr) {
|
||||
return input_port->get_source_tensor();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceOpONNX::get_source_tensor(const std::string& input_name) const {
|
||||
const auto input_port = get_input_port(input_name);
|
||||
if (input_port != nullptr) {
|
||||
return input_port->get_source_tensor();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool PlaceOpONNX::is_input() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlaceOpONNX::is_output() const {
|
||||
return false;
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <editor.hpp>
|
||||
#include <frontend_manager/place.hpp>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
@ -15,17 +16,18 @@ public:
|
||||
PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
PlaceInputEdgeONNX(onnx_editor::InputEdge&& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
|
||||
// internal usage
|
||||
onnx_editor::InputEdge get_input_edge() const;
|
||||
|
||||
// external usage
|
||||
bool is_input() const override;
|
||||
|
||||
bool is_output() const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
bool is_equal_data(Place::Ptr another) const override;
|
||||
|
||||
Place::Ptr get_source_tensor() const override;
|
||||
std::vector<Place::Ptr> get_consuming_operations() const override;
|
||||
Place::Ptr get_producing_operation() const override;
|
||||
Place::Ptr get_producing_port() const override;
|
||||
|
||||
private:
|
||||
onnx_editor::InputEdge m_edge;
|
||||
@ -37,17 +39,18 @@ public:
|
||||
PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
PlaceOutputEdgeONNX(onnx_editor::OutputEdge&& edge, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
|
||||
// internal usage
|
||||
onnx_editor::OutputEdge get_output_edge() const;
|
||||
|
||||
// external usage
|
||||
bool is_input() const override;
|
||||
|
||||
bool is_output() const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
bool is_equal_data(Place::Ptr another) const override;
|
||||
|
||||
Place::Ptr get_target_tensor() const override;
|
||||
std::vector<Place::Ptr> get_consuming_ports() const override;
|
||||
Place::Ptr get_producing_operation() const override;
|
||||
std::vector<Place::Ptr> get_consuming_operations() const override;
|
||||
|
||||
private:
|
||||
onnx_editor::OutputEdge m_edge;
|
||||
@ -59,21 +62,16 @@ public:
|
||||
PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
PlaceTensorONNX(std::string&& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
|
||||
// external usage
|
||||
std::vector<std::string> get_names() const override;
|
||||
|
||||
Place::Ptr get_producing_port() const override;
|
||||
|
||||
std::vector<Place::Ptr> get_consuming_ports() const override;
|
||||
|
||||
Ptr get_input_port(int input_port_index) const override;
|
||||
|
||||
Place::Ptr get_producing_operation() const override;
|
||||
bool is_input() const override;
|
||||
|
||||
bool is_output() const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
bool is_equal_data(Place::Ptr another) const override;
|
||||
std::vector<Place::Ptr> get_consuming_operations() const override;
|
||||
|
||||
private:
|
||||
std::string m_name;
|
||||
@ -86,6 +84,10 @@ public:
|
||||
PlaceOpONNX(onnx_editor::EditorNode&& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
std::vector<std::string> get_names() const override;
|
||||
|
||||
// internal usage
|
||||
onnx_editor::EditorNode get_editor_node() const;
|
||||
|
||||
// external usage
|
||||
Place::Ptr get_output_port() const override;
|
||||
Place::Ptr get_output_port(int output_port_index) const override;
|
||||
Place::Ptr get_output_port(const std::string& output_port_name) const override;
|
||||
@ -94,6 +96,27 @@ public:
|
||||
Place::Ptr get_input_port(int input_port_index) const override;
|
||||
Place::Ptr get_input_port(const std::string& input_name) const override;
|
||||
|
||||
std::vector<Place::Ptr> get_consuming_ports() const override;
|
||||
std::vector<Place::Ptr> get_consuming_operations() const override;
|
||||
std::vector<Place::Ptr> get_consuming_operations(int output_port_index) const override;
|
||||
std::vector<Place::Ptr> get_consuming_operations(const std::string& output_port_name) const override;
|
||||
|
||||
Place::Ptr get_producing_operation() const override;
|
||||
Place::Ptr get_producing_operation(int input_port_index) const override;
|
||||
Place::Ptr get_producing_operation(const std::string& input_port_name) const override;
|
||||
|
||||
Place::Place::Ptr get_target_tensor() const override;
|
||||
Place::Ptr get_target_tensor(int output_port_index) const override;
|
||||
Place::Ptr get_target_tensor(const std::string& output_name) const override;
|
||||
|
||||
Place::Place::Ptr get_source_tensor() const override;
|
||||
Place::Ptr get_source_tensor(int input_port_index) const override;
|
||||
Place::Ptr get_source_tensor(const std::string& input_name) const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
bool is_input() const override;
|
||||
bool is_output() const override;
|
||||
|
||||
private:
|
||||
onnx_editor::EditorNode m_node;
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
|
@ -693,6 +693,26 @@ NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_ind
|
||||
EXPECT_EQ(edge2.m_new_input_name, "custom_input_name_2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_index) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{0}, EditorInput{0, "custom_input_name_1"});
|
||||
EXPECT_EQ(edge.m_node_idx, 0);
|
||||
EXPECT_EQ(edge.m_port_idx, 0);
|
||||
EXPECT_EQ(edge.m_new_input_name, "custom_input_name_1");
|
||||
|
||||
const InputEdge edge2 = editor.find_input_edge(EditorNode{5}, EditorInput{0});
|
||||
EXPECT_EQ(edge2.m_node_idx, 5);
|
||||
EXPECT_EQ(edge2.m_port_idx, 0);
|
||||
|
||||
try {
|
||||
editor.find_input_edge(EditorNode{99}, EditorInput{"conv1/7x7_s2_1"});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
@ -767,6 +787,21 @@ NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_i
|
||||
EXPECT_EQ(edge2.m_port_idx, 1);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_index) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
const OutputEdge edge = editor.find_output_edge(EditorNode{5}, EditorOutput{1});
|
||||
EXPECT_EQ(edge.m_node_idx, 5);
|
||||
EXPECT_EQ(edge.m_port_idx, 1);
|
||||
|
||||
try {
|
||||
editor.find_output_edge(EditorNode{99}, EditorOutput{"conv1/7x7_s2_1"});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.onnx")};
|
||||
|
||||
@ -1044,6 +1079,12 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"});
|
||||
EXPECT_EQ(is_correct_node, true);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{2});
|
||||
EXPECT_EQ(is_correct_node, true);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{99});
|
||||
EXPECT_EQ(is_correct_node, false);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}});
|
||||
EXPECT_EQ(is_correct_node, false);
|
||||
|
||||
@ -1054,6 +1095,32 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) {
|
||||
EXPECT_EQ(is_correct_node, false);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_get_node_index) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")};
|
||||
|
||||
EXPECT_EQ(editor.get_node_index(EditorNode{2}), 2);
|
||||
EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"relu1"}}), 0);
|
||||
EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"split2"}}), 5);
|
||||
EXPECT_EQ(editor.get_node_index(EditorNode{"relu1_name"}), 0);
|
||||
|
||||
try {
|
||||
editor.get_node_index(EditorNode{99});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos);
|
||||
}
|
||||
|
||||
try {
|
||||
editor.get_node_index(EditorNode{"add_ambiguous_name"});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(
|
||||
msg.find(
|
||||
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_input_edge_from_tensor_with_single_consumer) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_ab.onnx")};
|
||||
|
||||
@ -1365,15 +1432,18 @@ NGRAPH_TEST(onnx_editor, get_input_ports) {
|
||||
editor.get_input_ports(EditorNode{"add_ambiguous_name"});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("The node with name: add_ambiguous_name, output_name: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
EXPECT_TRUE(
|
||||
msg.find(
|
||||
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
}
|
||||
try {
|
||||
editor.get_input_ports(EditorNode{""});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("The node with name: not_given, output_name: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
EXPECT_TRUE(
|
||||
msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
NGRAPH_TEST(onnx_editor, get_output_ports) {
|
||||
@ -1393,14 +1463,17 @@ NGRAPH_TEST(onnx_editor, get_output_ports) {
|
||||
editor.get_output_ports(EditorNode{"add_ambiguous_name"});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("The node with name: add_ambiguous_name, output_name: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
EXPECT_TRUE(
|
||||
msg.find(
|
||||
"The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
}
|
||||
try {
|
||||
editor.get_output_ports(EditorNode{""});
|
||||
} catch (const std::exception& e) {
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("The node with name: not_given, output_name: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
EXPECT_TRUE(
|
||||
msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
@ -89,13 +89,23 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
place.def(
|
||||
"get_consuming_operations",
|
||||
[](const ngraph::frontend::Place& self, py::object outputPortIndex) {
|
||||
if (outputPortIndex == py::none()) {
|
||||
return self.get_consuming_operations();
|
||||
[](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) {
|
||||
if (outputName == py::none()) {
|
||||
if (outputPortIndex == py::none()) {
|
||||
return self.get_consuming_operations();
|
||||
} else {
|
||||
return self.get_consuming_operations(py::cast<int>(outputPortIndex));
|
||||
}
|
||||
} else {
|
||||
return self.get_consuming_operations(py::cast<int>(outputPortIndex));
|
||||
if (outputPortIndex == py::none()) {
|
||||
return self.get_consuming_operations(py::cast<std::string>(outputName));
|
||||
} else {
|
||||
return self.get_consuming_operations(py::cast<std::string>(outputName),
|
||||
py::cast<int>(outputPortIndex));
|
||||
}
|
||||
}
|
||||
},
|
||||
py::arg("outputName") = py::none(),
|
||||
py::arg("outputPortIndex") = py::none(),
|
||||
R"(
|
||||
Returns references to all operation nodes that consume data from this place for specified output port.
|
||||
@ -103,6 +113,8 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outputName : str
|
||||
Name of output port group. May not be set if node has one output port group.
|
||||
outputPortIndex : int
|
||||
If place is an operational node it specifies which output port should be considered
|
||||
May not be set if node has only one output port.
|
||||
@ -115,13 +127,22 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
place.def(
|
||||
"get_target_tensor",
|
||||
[](const ngraph::frontend::Place& self, py::object outputPortIndex) {
|
||||
if (outputPortIndex == py::none()) {
|
||||
return self.get_target_tensor();
|
||||
[](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) {
|
||||
if (outputName == py::none()) {
|
||||
if (outputPortIndex == py::none()) {
|
||||
return self.get_target_tensor();
|
||||
} else {
|
||||
return self.get_target_tensor(py::cast<int>(outputPortIndex));
|
||||
}
|
||||
} else {
|
||||
return self.get_target_tensor(py::cast<int>(outputPortIndex));
|
||||
if (outputPortIndex == py::none()) {
|
||||
return self.get_target_tensor(py::cast<std::string>(outputName));
|
||||
} else {
|
||||
return self.get_target_tensor(py::cast<std::string>(outputName), py::cast<int>(outputPortIndex));
|
||||
}
|
||||
}
|
||||
},
|
||||
py::arg("outputName") = py::none(),
|
||||
py::arg("outputPortIndex") = py::none(),
|
||||
R"(
|
||||
Returns a tensor place that gets data from this place; applicable for operations,
|
||||
@ -129,6 +150,8 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outputName : str
|
||||
Name of output port group. May not be set if node has one output port group.
|
||||
outputPortIndex : int
|
||||
Output port index if the current place is an operation node and has multiple output ports.
|
||||
May not be set if place has only one output port.
|
||||
@ -141,19 +164,31 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
place.def(
|
||||
"get_producing_operation",
|
||||
[](const ngraph::frontend::Place& self, py::object inputPortIndex) {
|
||||
if (inputPortIndex == py::none()) {
|
||||
return self.get_producing_operation();
|
||||
[](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) {
|
||||
if (inputName == py::none()) {
|
||||
if (inputPortIndex == py::none()) {
|
||||
return self.get_producing_operation();
|
||||
} else {
|
||||
return self.get_producing_operation(py::cast<int>(inputPortIndex));
|
||||
}
|
||||
} else {
|
||||
return self.get_producing_operation(py::cast<int>(inputPortIndex));
|
||||
if (inputPortIndex == py::none()) {
|
||||
return self.get_producing_operation(py::cast<std::string>(inputName));
|
||||
} else {
|
||||
return self.get_producing_operation(py::cast<std::string>(inputName),
|
||||
py::cast<int>(inputPortIndex));
|
||||
}
|
||||
}
|
||||
},
|
||||
py::arg("inputName") = py::none(),
|
||||
py::arg("inputPortIndex") = py::none(),
|
||||
R"(
|
||||
Get an operation node place that immediately produces data for this place.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputName : str
|
||||
Name of port group. May not be set if node has one input port group.
|
||||
inputPortIndex : int
|
||||
If a given place is itself an operation node, this specifies a port index.
|
||||
May not be set if place has only one input port.
|
||||
@ -260,13 +295,22 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
place.def(
|
||||
"get_source_tensor",
|
||||
[](const ngraph::frontend::Place& self, py::object inputPortIndex) {
|
||||
if (inputPortIndex == py::none()) {
|
||||
return self.get_source_tensor();
|
||||
[](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) {
|
||||
if (inputName == py::none()) {
|
||||
if (inputPortIndex == py::none()) {
|
||||
return self.get_source_tensor();
|
||||
} else {
|
||||
return self.get_source_tensor(py::cast<int>(inputPortIndex));
|
||||
}
|
||||
} else {
|
||||
return self.get_source_tensor(py::cast<int>(inputPortIndex));
|
||||
if (inputPortIndex == py::none()) {
|
||||
return self.get_source_tensor(py::cast<std::string>(inputName));
|
||||
} else {
|
||||
return self.get_source_tensor(py::cast<std::string>(inputName), py::cast<int>(inputPortIndex));
|
||||
}
|
||||
}
|
||||
},
|
||||
py::arg("inputName") = py::none(),
|
||||
py::arg("inputPortIndex") = py::none(),
|
||||
R"(
|
||||
Returns a tensor place that supplies data for this place; applicable for operations,
|
||||
@ -274,6 +318,8 @@ void regclass_pyngraph_Place(py::module m) {
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputName : str
|
||||
Name of port group. May not be set if node has one input port group.
|
||||
inputPortIndex : int
|
||||
Input port index for operational node. May not be specified if place has only one input port.
|
||||
|
||||
|
@ -107,12 +107,29 @@ public:
|
||||
std::vector<Place::Ptr> get_consuming_operations() const override {
|
||||
m_stat.m_get_consuming_operations++;
|
||||
m_stat.m_lastArgInt = -1;
|
||||
m_stat.m_lastArgString = "";
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> get_consuming_operations(int outputPortIndex) const override {
|
||||
m_stat.m_get_consuming_operations++;
|
||||
m_stat.m_lastArgInt = outputPortIndex;
|
||||
m_stat.m_lastArgString = "";
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> get_consuming_operations(const std::string& outputName) const override {
|
||||
m_stat.m_get_consuming_operations++;
|
||||
m_stat.m_lastArgInt = -1;
|
||||
m_stat.m_lastArgString = outputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> get_consuming_operations(const std::string& outputName,
|
||||
int outputPortIndex) const override {
|
||||
m_stat.m_get_consuming_operations++;
|
||||
m_stat.m_lastArgInt = outputPortIndex;
|
||||
m_stat.m_lastArgString = outputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
@ -128,6 +145,20 @@ public:
|
||||
return std::make_shared<PlaceMockPy>();
|
||||
}
|
||||
|
||||
Place::Ptr get_target_tensor(const std::string& outputName) const override {
|
||||
m_stat.m_get_target_tensor++;
|
||||
m_stat.m_lastArgInt = -1;
|
||||
m_stat.m_lastArgString = outputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
Place::Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const override {
|
||||
m_stat.m_get_target_tensor++;
|
||||
m_stat.m_lastArgInt = outputPortIndex;
|
||||
m_stat.m_lastArgString = outputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
Place::Ptr get_producing_operation() const override {
|
||||
m_stat.m_get_producing_operation++;
|
||||
m_stat.m_lastArgInt = -1;
|
||||
@ -140,6 +171,20 @@ public:
|
||||
return std::make_shared<PlaceMockPy>();
|
||||
}
|
||||
|
||||
Place::Ptr get_producing_operation(const std::string& inputName) const override {
|
||||
m_stat.m_get_producing_operation++;
|
||||
m_stat.m_lastArgInt = -1;
|
||||
m_stat.m_lastArgString = inputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
Place::Ptr get_producing_operation(const std::string& inputName, int inputPortIndex) const override {
|
||||
m_stat.m_get_producing_operation++;
|
||||
m_stat.m_lastArgInt = inputPortIndex;
|
||||
m_stat.m_lastArgString = inputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
Place::Ptr get_producing_port() const override {
|
||||
m_stat.m_get_producing_port++;
|
||||
return std::make_shared<PlaceMockPy>();
|
||||
@ -236,6 +281,20 @@ public:
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
Place::Ptr get_source_tensor(const std::string& inputName) const override {
|
||||
m_stat.m_get_source_tensor++;
|
||||
m_stat.m_lastArgInt = -1;
|
||||
m_stat.m_lastArgString = inputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
Place::Ptr get_source_tensor(const std::string& inputName, int inputPortIndex) const override {
|
||||
m_stat.m_get_source_tensor++;
|
||||
m_stat.m_lastArgInt = inputPortIndex;
|
||||
m_stat.m_lastArgString = inputName;
|
||||
return {std::make_shared<PlaceMockPy>()};
|
||||
}
|
||||
|
||||
//---------------Stat--------------------
|
||||
PlaceStat get_stat() const {
|
||||
return m_stat;
|
||||
|
@ -9,6 +9,7 @@ from ngraph import PartialShape
|
||||
from ngraph.frontend import FrontEndManager
|
||||
|
||||
|
||||
# ------Test input model 1------
|
||||
# in1 in2 in3
|
||||
# | | |
|
||||
# \ / |
|
||||
@ -24,9 +25,32 @@ from ngraph.frontend import FrontEndManager
|
||||
# / \ |
|
||||
# out1 out2 out4
|
||||
#
|
||||
#
|
||||
# ------Test input model 2------
|
||||
# in1 in2
|
||||
# | |
|
||||
# \ /
|
||||
# +--------+
|
||||
# | Add |
|
||||
# +--------+
|
||||
# <add_out>
|
||||
# |
|
||||
# +--------+
|
||||
# | Split |
|
||||
# |(split2)|
|
||||
# +--------+
|
||||
# / \
|
||||
# <sp_out1> <sp_out2>
|
||||
# +-------+ +-------+
|
||||
# | Abs | | Sin |
|
||||
# | (abs1)| | |
|
||||
# +------ + +-------+
|
||||
# | |
|
||||
# out1 out2
|
||||
#
|
||||
def create_test_onnx_models():
|
||||
models = {}
|
||||
# Input model
|
||||
# Input model 1
|
||||
add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"])
|
||||
split = onnx.helper.make_node("Split", inputs=["add_out"],
|
||||
outputs=["out1", "out2"], name="split1", axis=0)
|
||||
@ -48,6 +72,24 @@ def create_test_onnx_models():
|
||||
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Input model 2
|
||||
split_2 = onnx.helper.make_node("Split", inputs=["add_out"],
|
||||
outputs=["sp_out1", "sp_out2"], name="split2", axis=0)
|
||||
abs = onnx.helper.make_node("Abs", inputs=["sp_out1"], outputs=["out1"], name="abs1")
|
||||
sin = onnx.helper.make_node("Sin", inputs=["sp_out2"], outputs=["out2"])
|
||||
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
]
|
||||
graph = make_graph([add, split_2, abs, sin], "test_graph_2", input_tensors, output_tensors)
|
||||
models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
@ -312,8 +354,9 @@ def test_extract_subgraph_4():
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
|
||||
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
|
||||
place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
|
||||
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
place5 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
@ -377,8 +420,9 @@ def test_override_all_inputs():
|
||||
|
||||
place1 = model.get_place_by_operation_name_and_input_port(
|
||||
operationName="split1", inputPortIndex=0)
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
|
||||
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
|
||||
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="in3")
|
||||
model.override_all_inputs(inputs=[place1, place2, place3, place4])
|
||||
result_func = fe.convert(model)
|
||||
@ -432,11 +476,15 @@ def test_is_input_output():
|
||||
assert not place3.is_input()
|
||||
assert not place3.is_output()
|
||||
|
||||
place4 = place1 = model.get_place_by_operation_name_and_input_port(
|
||||
place4 = model.get_place_by_operation_name_and_input_port(
|
||||
operationName="split1", inputPortIndex=0)
|
||||
assert not place4.is_input()
|
||||
assert not place4.is_output()
|
||||
|
||||
place5 = model.get_place_by_operation_name(operationName="split1")
|
||||
assert not place5.is_input()
|
||||
assert not place5.is_output()
|
||||
|
||||
|
||||
def test_set_partial_shape():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
@ -520,12 +568,14 @@ def test_is_equal():
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
assert place2.is_equal(place2)
|
||||
|
||||
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
|
||||
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
assert place3.is_equal(place4)
|
||||
|
||||
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
|
||||
place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place6 = model.get_place_by_tensor_name(tensorName="out1").get_input_port(inputPortIndex=0)
|
||||
place6 = out1_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
assert place5.is_equal(place6)
|
||||
|
||||
place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port()
|
||||
@ -538,6 +588,10 @@ def test_is_equal():
|
||||
assert not place6.is_equal(place7)
|
||||
assert not place8.is_equal(place2)
|
||||
|
||||
place9 = model.get_place_by_operation_name(operationName="split1")
|
||||
assert place2.get_producing_operation().is_equal(place9)
|
||||
assert not place9.is_equal(place2)
|
||||
|
||||
|
||||
def test_is_equal_data():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
@ -560,11 +614,12 @@ def test_is_equal_data():
|
||||
place4 = place2.get_producing_port()
|
||||
assert place2.is_equal_data(place4)
|
||||
|
||||
place5 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
|
||||
place5 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
assert place2.is_equal_data(place5)
|
||||
assert place4.is_equal_data(place5)
|
||||
|
||||
place6 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
|
||||
place6 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
|
||||
assert place6.is_equal_data(place5)
|
||||
|
||||
place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
@ -648,3 +703,198 @@ def test_get_input_port():
|
||||
|
||||
assert not split_op.get_input_port(inputPortIndex=1)
|
||||
assert not split_op.get_input_port(inputName="not_existed")
|
||||
|
||||
|
||||
def test_get_consuming_ports():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
add_tensor_consuming_ports = place1.get_consuming_ports()
|
||||
assert len(add_tensor_consuming_ports) == 3
|
||||
place2 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
assert add_tensor_consuming_ports[0].is_equal(place2)
|
||||
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
|
||||
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
assert add_tensor_consuming_ports[1].is_equal(place3)
|
||||
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
|
||||
assert add_tensor_consuming_ports[2].is_equal(place4)
|
||||
|
||||
add_op_consuming_ports = place1.get_producing_operation().get_consuming_ports()
|
||||
assert len(add_op_consuming_ports) == len(add_tensor_consuming_ports)
|
||||
for i in range(len(add_op_consuming_ports)):
|
||||
assert add_op_consuming_ports[i].is_equal(add_tensor_consuming_ports[i])
|
||||
|
||||
|
||||
def test_get_consuming_ports_2():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
split_op_consuming_ports = split_op.get_consuming_ports()
|
||||
assert len(split_op_consuming_ports) == 2
|
||||
abs_input_port = model.get_place_by_operation_name(operationName="abs1").get_input_port(inputPortIndex=0)
|
||||
assert split_op_consuming_ports[0].is_equal(abs_input_port)
|
||||
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
|
||||
sin_input_port = out2_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
assert split_op_consuming_ports[1].is_equal(sin_input_port)
|
||||
|
||||
split_out_port_0 = split_op.get_output_port(outputPortIndex=0)
|
||||
split_out_port_0_consuming_ports = split_out_port_0.get_consuming_ports()
|
||||
assert len(split_out_port_0_consuming_ports) == 1
|
||||
assert split_out_port_0_consuming_ports[0].is_equal(abs_input_port)
|
||||
|
||||
|
||||
def test_get_producing_operation():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
split_tensor_out_2 = model.get_place_by_tensor_name(tensorName="sp_out2")
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
assert split_tensor_out_2.get_producing_operation().is_equal(split_op)
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
split_out_port_2 = split_op.get_output_port(outputPortIndex=1)
|
||||
assert split_out_port_2.get_producing_operation().is_equal(split_op)
|
||||
|
||||
|
||||
def test_get_producing_operation_2():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
abs_op = model.get_place_by_operation_name(operationName="abs1")
|
||||
abs_port_0 = abs_op.get_input_port()
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
assert abs_port_0.get_producing_operation().is_equal(split_op)
|
||||
assert abs_op.get_producing_operation().is_equal(split_op)
|
||||
|
||||
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
add_op = add_out_tensor.get_producing_operation()
|
||||
assert not add_op.get_producing_operation()
|
||||
|
||||
split_op_producing_op = split_op.get_producing_operation(inputName="add_out")
|
||||
assert split_op_producing_op.is_equal(add_op)
|
||||
|
||||
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
|
||||
sin_op = out2_tensor.get_producing_operation()
|
||||
assert sin_op.get_producing_operation(inputPortIndex=0).is_equal(split_op)
|
||||
|
||||
|
||||
def test_get_consuming_operations():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
split_op_consuming_ops = split_op.get_consuming_operations()
|
||||
abs_op = model.get_place_by_operation_name(operationName="abs1")
|
||||
sin_op = model.get_place_by_tensor_name(tensorName="out2").get_producing_operation()
|
||||
|
||||
assert len(split_op_consuming_ops) == 2
|
||||
assert split_op_consuming_ops[0].is_equal(abs_op)
|
||||
assert split_op_consuming_ops[1].is_equal(sin_op)
|
||||
|
||||
split_op_port = split_op.get_input_port(inputPortIndex=0)
|
||||
split_op_port_consuming_ops = split_op_port.get_consuming_operations()
|
||||
|
||||
assert len(split_op_port_consuming_ops) == 1
|
||||
assert split_op_port_consuming_ops[0].is_equal(split_op)
|
||||
|
||||
add_out_port = model.get_place_by_tensor_name(tensorName="add_out").get_producing_port()
|
||||
add_out_port_consuming_ops = add_out_port.get_consuming_operations()
|
||||
assert len(add_out_port_consuming_ops) == 1
|
||||
assert add_out_port_consuming_ops[0].is_equal(split_op)
|
||||
|
||||
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
|
||||
sp_out2_tensor_consuming_ops = sp_out2_tensor.get_consuming_operations()
|
||||
assert len(sp_out2_tensor_consuming_ops) == 1
|
||||
assert sp_out2_tensor_consuming_ops[0].is_equal(sin_op)
|
||||
|
||||
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
|
||||
out2_tensor_consuming_ops = out2_tensor.get_consuming_operations()
|
||||
assert len(out2_tensor_consuming_ops) == 0
|
||||
out2_port_consuming_ops = out2_tensor.get_producing_port().get_consuming_operations()
|
||||
assert len(out2_port_consuming_ops) == 0
|
||||
|
||||
split_out_1_consuming_ops = split_op.get_consuming_operations(outputPortIndex=1)
|
||||
assert len(split_out_1_consuming_ops) == 1
|
||||
split_out_sp_out_2_consuming_ops = split_op.get_consuming_operations(outputName="sp_out2")
|
||||
assert len(split_out_sp_out_2_consuming_ops) == 1
|
||||
assert split_out_1_consuming_ops[0].is_equal(split_out_sp_out_2_consuming_ops[0])
|
||||
assert split_out_1_consuming_ops[0].is_equal(sin_op)
|
||||
|
||||
|
||||
def test_get_target_tensor():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
assert not split_op.get_target_tensor()
|
||||
|
||||
split_op_tensor_1 = split_op.get_target_tensor(outputPortIndex=1)
|
||||
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
|
||||
assert split_op_tensor_1.is_equal(sp_out2_tensor)
|
||||
|
||||
split_tensor_sp_out2 = split_op.get_target_tensor(outputName="sp_out2")
|
||||
assert split_tensor_sp_out2.is_equal(split_op_tensor_1)
|
||||
|
||||
abs_op = model.get_place_by_operation_name(operationName="abs1")
|
||||
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
|
||||
assert abs_op.get_target_tensor().is_equal(out1_tensor)
|
||||
|
||||
|
||||
def test_get_source_tensor():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
add_op = add_out_tensor.get_producing_operation()
|
||||
assert not add_op.get_source_tensor()
|
||||
|
||||
add_op_in_tensor_1 = add_op.get_source_tensor(inputPortIndex=1)
|
||||
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
|
||||
assert add_op_in_tensor_1.is_equal(in2_tensor)
|
||||
|
||||
add_op_in_tensor_in2 = add_op.get_source_tensor(inputName="in2")
|
||||
assert add_op_in_tensor_in2.is_equal(in2_tensor)
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
assert split_op.get_source_tensor().is_equal(add_out_tensor)
|
||||
|
||||
|
||||
def test_get_producing_port():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
model = fe.load("input_model_2.onnx")
|
||||
assert model
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split2")
|
||||
split_op_in_port = split_op.get_input_port()
|
||||
split_op_in_port_prod_port = split_op_in_port.get_producing_port()
|
||||
|
||||
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
add_op = add_out_tensor.get_producing_operation()
|
||||
add_op_out_port = add_op.get_output_port()
|
||||
|
||||
assert split_op_in_port_prod_port.is_equal(add_op_out_port)
|
||||
|
@ -440,6 +440,16 @@ def test_place_get_consuming_operations():
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_consuming_operations == 2
|
||||
assert stat.lastArgInt == -1
|
||||
assert place.get_consuming_operations(outputName="2") is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_consuming_operations == 3
|
||||
assert stat.lastArgInt == -1
|
||||
assert stat.lastArgString == "2"
|
||||
assert place.get_consuming_operations(outputName="3", outputPortIndex=33) is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_consuming_operations == 4
|
||||
assert stat.lastArgInt == 33
|
||||
assert stat.lastArgString == "3"
|
||||
|
||||
|
||||
@mock_needed
|
||||
@ -453,6 +463,16 @@ def test_place_get_target_tensor():
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_target_tensor == 2
|
||||
assert stat.lastArgInt == -1
|
||||
assert place.get_target_tensor(outputName="2") is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_target_tensor == 3
|
||||
assert stat.lastArgInt == -1
|
||||
assert stat.lastArgString == "2"
|
||||
assert place.get_target_tensor(outputName="3", outputPortIndex=33) is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_target_tensor == 4
|
||||
assert stat.lastArgInt == 33
|
||||
assert stat.lastArgString == "3"
|
||||
|
||||
|
||||
@mock_needed
|
||||
@ -466,6 +486,16 @@ def test_place_get_producing_operation():
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_producing_operation == 2
|
||||
assert stat.lastArgInt == -1
|
||||
assert place.get_producing_operation(inputName="2") is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_producing_operation == 3
|
||||
assert stat.lastArgInt == -1
|
||||
assert stat.lastArgString == "2"
|
||||
assert place.get_producing_operation(inputName="3", inputPortIndex=33) is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_producing_operation == 4
|
||||
assert stat.lastArgInt == 33
|
||||
assert stat.lastArgString == "3"
|
||||
|
||||
|
||||
@mock_needed
|
||||
@ -551,3 +581,13 @@ def test_place_get_source_tensor():
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_source_tensor == 2
|
||||
assert stat.lastArgInt == 22
|
||||
assert place.get_source_tensor(inputName="2") is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_source_tensor == 3
|
||||
assert stat.lastArgInt == -1
|
||||
assert stat.lastArgString == "2"
|
||||
assert place.get_source_tensor(inputName="3", inputPortIndex=33) is not None
|
||||
stat = get_place_stat(place)
|
||||
assert stat.get_source_tensor == 4
|
||||
assert stat.lastArgInt == 33
|
||||
assert stat.lastArgString == "3"
|
||||
|
Loading…
Reference in New Issue
Block a user