Integration ONNX Editor with FE API (#6773)

Mateusz Bencer 2021-07-27 12:20:39 +02:00 committed by GitHub
parent 9acedbdacf
commit dc5f44e929
14 changed files with 1203 additions and 90 deletions

View File

@@ -277,7 +277,7 @@ void InputModel::set_partial_shape(Place::Ptr place, const ngraph::PartialShape&
ngraph::PartialShape InputModel::get_partial_shape(Place::Ptr place) const
{
- FRONT_END_NOT_IMPLEMENTED(set_partial_shape);
+ FRONT_END_NOT_IMPLEMENTED(get_partial_shape);
}
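This one-line change is a copy-paste fix: judging by the fix, the macro's argument is stringified into the thrown error message, so get_partial_shape used to report itself as set_partial_shape. A simplified stand-in for the real macro (the actual definition lives in frontend_exceptions.hpp and may differ), shown only to illustrate the mechanism:

    #include <stdexcept>
    #include <string>

    // hypothetical, simplified sketch - the argument name ends up in the error text
    #define FRONT_END_NOT_IMPLEMENTED(NAME) \
        throw std::runtime_error(std::string{#NAME} + " is not implemented")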
void InputModel::set_element_type(Place::Ptr place, const ngraph::element::Type&)

View File

@@ -2,55 +2,143 @@
// SPDX-License-Identifier: Apache-2.0
//
- #include "input_model.hpp"
+ #include <frontend_manager/frontend_exceptions.hpp>
+ #include <input_model.hpp>
+ #include <place.hpp>
- #include "place.hpp"
using namespace ngraph;
using namespace ngraph::frontend;
InputModelONNX::InputModelONNX(const std::string& path)
- : m_editor(path)
+ : m_editor{std::make_shared<onnx_editor::ONNXModelEditor>(path)}
{
}
std::vector<Place::Ptr> InputModelONNX::get_inputs() const
{
- auto inputs = m_editor.model_inputs();
- std::vector<Place::Ptr> ret;
- ret.reserve(inputs.size());
+ const auto& inputs = m_editor->model_inputs();
+ std::vector<Place::Ptr> in_places;
+ in_places.reserve(inputs.size());
for (const auto& input : inputs)
{
- ret.push_back(std::make_shared<PlaceTensorONNX>(input, m_editor));
+ in_places.push_back(std::make_shared<PlaceTensorONNX>(input, m_editor));
}
- return ret;
+ return in_places;
}
std::vector<Place::Ptr> InputModelONNX::get_outputs() const
{
const auto& outputs = m_editor->model_outputs();
std::vector<Place::Ptr> out_places;
out_places.reserve(outputs.size());
for (const auto& output : outputs)
{
out_places.push_back(std::make_shared<PlaceTensorONNX>(output, m_editor));
}
return out_places;
}
Place::Ptr InputModelONNX::get_place_by_tensor_name(const std::string& tensor_name) const
{
NGRAPH_CHECK(m_editor->is_correct_tensor_name(tensor_name),
"The tensor with name: " + tensor_name + " does not exist in the graph");
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
}
Place::Ptr
InputModelONNX::get_place_by_operation_name_and_input_port(const std::string& operation_name,
int input_port_index)
{
const auto edge =
m_editor->find_input_edge(onnx_editor::EditorNode(operation_name), input_port_index);
return std::make_shared<PlaceInputEdgeONNX>(edge, m_editor);
}
void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape)
{
std::map<std::string, ngraph::PartialShape> m;
m[place->get_names()[0]] = shape;
- m_editor.set_input_shapes(m);
+ m_editor->set_input_shapes(m);
}
ngraph::PartialShape InputModelONNX::get_partial_shape(Place::Ptr place) const
{
return m_editor->get_tensor_shape(place->get_names().at(0));
}
void InputModelONNX::set_element_type(Place::Ptr place, const ngraph::element::Type& type)
{
std::map<std::string, ngraph::element::Type_t> m;
m[place->get_names()[0]] = type;
- m_editor.set_input_types(m);
+ m_editor->set_input_types(m);
}
std::shared_ptr<Function> InputModelONNX::decode()
{
- return m_editor.decode();
+ return m_editor->decode();
}
std::shared_ptr<Function> InputModelONNX::convert()
{
- return m_editor.get_function();
+ return m_editor->get_function();
}
// Editor features
void InputModelONNX::override_all_outputs(const std::vector<Place::Ptr>& outputs)
{
extract_subgraph({}, outputs);
NGRAPH_CHECK(m_editor->model_outputs().size() == outputs.size(),
"Unexpected number of outputs after override_all_outputs");
NGRAPH_CHECK(std::all_of(std::begin(outputs),
std::end(outputs),
[](const Place::Ptr& place) { return place->is_output(); }),
"Not all provided arguments of override_all_outputs are new outputs of the model");
}
void InputModelONNX::override_all_inputs(const std::vector<Place::Ptr>& inputs)
{
const auto outputs_before_extraction = m_editor->model_outputs();
extract_subgraph({inputs}, {});
NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction),
std::end(outputs_before_extraction),
std::begin(m_editor->model_outputs())),
"All outputs should be preserved after override_all_inputs. Provided inputs does "
"not satisfy all outputs");
NGRAPH_CHECK(m_editor->model_inputs().size() == inputs.size(),
"Unexpected number of inputs after override_all_inputs");
}
void InputModelONNX::extract_subgraph(const std::vector<Place::Ptr>& inputs,
const std::vector<Place::Ptr>& outputs)
{
std::vector<onnx_editor::InputEdge> onnx_inputs;
onnx_inputs.reserve(inputs.size());
for (const auto& input : inputs)
{
if (const auto input_port = std::dynamic_pointer_cast<PlaceInputEdgeONNX>(input))
{
onnx_inputs.push_back(input_port->get_input_edge());
}
else if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(input))
{
auto name = tensor->get_names()[0];
const auto consumers = m_editor->find_output_consumers(name);
std::transform(std::begin(consumers),
std::end(consumers),
std::back_inserter(onnx_inputs),
[](const onnx_editor::InputEdge& edge) { return edge; });
}
}
std::vector<onnx_editor::OutputEdge> onnx_outputs;
onnx_outputs.reserve(outputs.size());
for (const auto& output : outputs)
{
const auto output_port = output->get_producing_port();
const auto onnx_output_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output_port);
NGRAPH_CHECK(onnx_output_edge,
"Non-onnx output place was passed as extraction subgraph argument");
onnx_outputs.push_back(onnx_output_edge->get_output_edge());
}
m_editor->cut_graph_fragment(onnx_inputs, onnx_outputs);
}
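Taken together, the methods above let a client cut a model down to a subgraph without touching the protobufs directly. A minimal sketch, assuming a model that contains a two-input Add node producing a tensor named add_out (as in the Python tests further down); the path is illustrative:

    using namespace ngraph::frontend;

    InputModelONNX model{"model.onnx"};
    const auto in1 = model.get_place_by_tensor_name("add_out")->get_input_port(0);
    const auto in2 = model.get_place_by_tensor_name("add_out")->get_input_port(1);
    const auto out = model.get_place_by_tensor_name("add_out");
    model.extract_subgraph({in1, in2}, {out}); // keep only the Add node
    const auto function = model.convert();     // nGraph Function of the subgraph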

View File

@@ -17,15 +17,25 @@ namespace ngraph
InputModelONNX(const std::string& path);
std::vector<Place::Ptr> get_inputs() const override;
std::vector<Place::Ptr> get_outputs() const override;
Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override;
Place::Ptr get_place_by_operation_name_and_input_port(const std::string& operation_name,
int input_port_index) override;
void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override;
ngraph::PartialShape get_partial_shape(Place::Ptr place) const override;
void set_element_type(Place::Ptr place, const ngraph::element::Type& type) override;
std::shared_ptr<Function> decode();
std::shared_ptr<Function> convert();
// Editor features
void override_all_outputs(const std::vector<Place::Ptr>& outputs) override;
void override_all_inputs(const std::vector<Place::Ptr>& inputs) override;
void extract_subgraph(const std::vector<Place::Ptr>& inputs,
const std::vector<Place::Ptr>& outputs) override;
private:
- onnx_editor::ONNXModelEditor m_editor;
+ std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
};
} // namespace frontend

View File

@@ -0,0 +1,134 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "place.hpp"
#include <frontend_manager/frontend_exceptions.hpp>
using namespace ngraph;
using namespace ngraph::frontend;
PlaceInputEdgeONNX::PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_edge{edge}
, m_editor{editor}
{
}
onnx_editor::InputEdge PlaceInputEdgeONNX::get_input_edge() const
{
return m_edge;
}
bool PlaceInputEdgeONNX::is_input() const
{
return m_editor->is_input(m_edge);
}
bool PlaceInputEdgeONNX::is_output() const
{
return false;
}
bool PlaceInputEdgeONNX::is_equal(Place::Ptr another) const
{
if (const auto in_edge = std::dynamic_pointer_cast<PlaceInputEdgeONNX>(another))
{
const auto& editor_edge = in_edge->get_input_edge();
return (editor_edge.m_node_idx == m_edge.m_node_idx) &&
(editor_edge.m_port_idx == m_edge.m_port_idx);
}
return false;
}
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_edge{edge}
, m_editor{editor}
{
}
onnx_editor::OutputEdge PlaceOutputEdgeONNX::get_output_edge() const
{
return m_edge;
}
bool PlaceOutputEdgeONNX::is_input() const
{
return false;
}
bool PlaceOutputEdgeONNX::is_output() const
{
return m_editor->is_output(m_edge);
}
bool PlaceOutputEdgeONNX::is_equal(Place::Ptr another) const
{
if (const auto out_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(another))
{
const auto& editor_edge = out_edge->get_output_edge();
return (editor_edge.m_node_idx == m_edge.m_node_idx) &&
(editor_edge.m_port_idx == m_edge.m_port_idx);
}
return false;
}
PlaceTensorONNX::PlaceTensorONNX(const std::string& name,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_name(name)
, m_editor(editor)
{
}
std::vector<std::string> PlaceTensorONNX::get_names() const
{
return {m_name};
}
Place::Ptr PlaceTensorONNX::get_producing_port() const
{
return std::make_shared<PlaceOutputEdgeONNX>(m_editor->find_output_edge(m_name), m_editor);
}
std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_ports() const
{
std::vector<Place::Ptr> ret;
auto edges = m_editor->find_output_consumers(m_name);
std::transform(edges.begin(),
edges.end(),
std::back_inserter(ret),
[this](const onnx_editor::InputEdge& edge) {
return std::make_shared<PlaceInputEdgeONNX>(edge, this->m_editor);
});
return ret;
}
Place::Ptr PlaceTensorONNX::get_input_port(int input_port_index) const
{
return std::make_shared<PlaceInputEdgeONNX>(
m_editor->find_input_edge(onnx_editor::EditorOutput(m_name),
onnx_editor::EditorInput(input_port_index)),
m_editor);
}
bool PlaceTensorONNX::is_input() const
{
const auto inputs = m_editor->model_inputs();
return std::find(std::begin(inputs), std::end(inputs), m_name) != std::end(inputs);
}
bool PlaceTensorONNX::is_output() const
{
const auto outputs = m_editor->model_outputs();
return std::find(std::begin(outputs), std::end(outputs), m_name) != std::end(outputs);
}
bool PlaceTensorONNX::is_equal(Place::Ptr another) const
{
if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(another))
{
return m_name == tensor->get_names().at(0);
}
return false;
}
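The is_equal overloads above only ever match places of the same kind: tensor places compare by name, edge places by (node index, port index), and anything else yields false. A short sketch of the resulting semantics (the tensor name is assumed from the test model):

    auto editor = std::make_shared<onnx_editor::ONNXModelEditor>("model.onnx");
    auto tensor = std::make_shared<PlaceTensorONNX>("add_out", editor);
    auto port_a = tensor->get_input_port(0);
    auto port_b = tensor->get_input_port(0);

    // port_a->is_equal(port_b) == true  - same (node index, port index) pair
    // tensor->is_equal(port_a) == false - a tensor never equals an edge place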

View File

@@ -4,7 +4,9 @@
#pragma once
#include <memory>
#include <frontend_manager/place.hpp>
#include <onnx_editor/editor.hpp>
namespace ngraph
{
@@ -13,65 +15,63 @@ namespace ngraph
class PlaceInputEdgeONNX : public Place
{
public:
- PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge)
-     : m_edge(edge)
- {
- }
+ PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge,
+                    std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
onnx_editor::InputEdge get_input_edge() const;
bool is_input() const override;
bool is_output() const override;
bool is_equal(Place::Ptr another) const override;
private:
onnx_editor::InputEdge m_edge;
+ const std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
};
class PlaceOutputEdgeONNX : public Place
{
public:
- PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge)
-     : m_edge(edge)
- {
- }
+ PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
+                     std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
onnx_editor::OutputEdge get_output_edge() const;
bool is_input() const override;
bool is_output() const override;
bool is_equal(Place::Ptr another) const override;
private:
onnx_editor::OutputEdge m_edge;
+ std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
};
class PlaceTensorONNX : public Place
{
public:
- PlaceTensorONNX(const std::string& name, const onnx_editor::ONNXModelEditor& editor)
-     : m_name(name)
-     , m_editor(editor)
- {
- }
+ PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
- std::vector<std::string> get_names() const override { return {m_name}; }
+ std::vector<std::string> get_names() const override;
- Place::Ptr get_producing_port() const override
- {
-     return std::make_shared<PlaceOutputEdgeONNX>(m_editor.find_output_edge(m_name));
- }
+ Place::Ptr get_producing_port() const override;
- std::vector<Place::Ptr> get_consuming_ports() const override
- {
-     std::vector<Place::Ptr> ret;
-     auto edges = m_editor.find_output_consumers(m_name);
-     std::transform(edges.begin(),
-                    edges.end(),
-                    std::back_inserter(ret),
-                    [](const onnx_editor::InputEdge& edge) {
-                        return std::make_shared<PlaceInputEdgeONNX>(edge);
-                    });
-     return ret;
- }
+ std::vector<Place::Ptr> get_consuming_ports() const override;
- Ptr get_input_port(int input_port_index) const override
- {
-     return std::make_shared<PlaceInputEdgeONNX>(m_editor.find_input_edge(
-         onnx_editor::EditorNode(m_name), onnx_editor::EditorInput(input_port_index)));
- }
+ Ptr get_input_port(int input_port_index) const override;
bool is_input() const override;
bool is_output() const override;
bool is_equal(Place::Ptr another) const override;
private:
std::string m_name;
- const onnx_editor::ONNXModelEditor& m_editor;
+ std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
};
} // namespace frontend
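Switching m_editor from a reference (and from a by-value member in InputModelONNX) to std::shared_ptr lets every place created by the model share one editor instance and keep it alive independently of the model object. A sketch of the ownership, with an illustrative path and tensor name:

    auto editor = std::make_shared<onnx_editor::ONNXModelEditor>("model.onnx");
    PlaceTensorONNX tensor{"add_out", editor}; // stores its own shared_ptr copy
    auto port = tensor.get_producing_port();   // a new place sharing the same editor
    // the editor is destroyed only after the last place referring to it is gone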

View File

@@ -1,6 +1,7 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/partial_shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ONNX_NAMESPACE
@@ -38,5 +39,11 @@ namespace ngraph
///
bool is_supported_ng_type(const element::Type_t& ng_type);
/// \brief Returns nG PartialShape based on onnx_shape.
///
/// \param onnx_shape The shape of a tensor represented in the ONNX way.
///
PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape);
} // namespace onnx_common
} // namespace ngraph

View File

@@ -88,5 +88,27 @@ namespace ngraph
return NG_2_ONNX_TYPES.count(ng_type) > 0;
}
PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
{
if (onnx_shape.dim_size() == 0)
{
return Shape{}; // empty list of dimensions denotes a scalar
}
std::vector<Dimension> dims;
for (const auto& onnx_dim : onnx_shape.dim())
{
if (onnx_dim.has_dim_value())
{
dims.emplace_back(onnx_dim.dim_value());
}
else // has_dim_param() == true or it is empty dim
{
dims.push_back(Dimension::dynamic());
}
}
return PartialShape{dims};
}
} // namespace onnx_common
} // namespace ngraph
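The behaviour of to_ng_shape for static, named-dynamic and anonymous dimensions can be seen with a hand-built proto. A sketch, assuming the ONNX protobuf headers are reachable as <onnx/onnx_pb.h> (the header path may vary with the ONNX build):

    #include <onnx/onnx_pb.h>

    ONNX_NAMESPACE::TensorShapeProto onnx_shape;
    onnx_shape.add_dim()->set_dim_value(3);       // static dimension
    onnx_shape.add_dim()->set_dim_param("batch"); // named dynamic dimension
    onnx_shape.add_dim();                         // empty dim - also dynamic

    const auto shape = ngraph::onnx_common::to_ng_shape(onnx_shape);
    // shape == PartialShape{3, Dimension::dynamic(), Dimension::dynamic()}
    // and a dim-less proto yields Shape{} - a scalar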

View File

@@ -99,6 +99,24 @@ namespace ngraph
///
ONNX_IMPORTER_API bool is_correct_and_unambiguous_node(const EditorNode& node) const;
/// \brief Returns true if the provided tensor name is correct (exists in the graph).
///
/// \param name The name of a tensor in the graph.
///
bool is_correct_tensor_name(const std::string& name) const;
/// \brief Get the name of the input port indicated by the input edge.
///
/// \note An empty string is returned if the port name is not found.
///
std::string get_input_port_name(const InputEdge& edge) const;
/// \brief Get the name of the output port indicated by the output edge.
///
/// \note An empty string is returned if the port name is not found.
///
std::string get_output_port_name(const OutputEdge& edge) const;
private:
std::vector<int> find_node_indexes(const std::string& node_name,
const std::string& output_name) const;

View File

@@ -53,6 +53,12 @@ namespace ngraph
/// the inputs specified in its parameter.
void set_input_shapes(const std::map<std::string, ngraph::PartialShape>& input_shapes);
/// \brief Get the shape of the ONNX tensor indicated by tensor_name.
///
/// \param tensor_name The name of the ONNX tensor.
///
PartialShape get_tensor_shape(const std::string& tensor_name) const;
/// \brief Extracts a subgraph constrained by input edges and output edges. In the end
/// the underlying ModelProto is modified - obsolete inputs, initializers, nodes
/// and outputs are removed from the in-memory model.
@@ -86,12 +92,25 @@
/// \brief Converts an edited ONNX model to an nGraph Function representation.
std::shared_ptr<Function> get_function() const;
- /// \brief Returns a list of all inputs of the in-memory model, including initializers.
+ /// \brief Returns a list of all inputs of the in-memory model.
/// The returned value might depend on the previous operations executed on an
/// instance of the model editor, in particular the subgraph extraction which
- /// can discard some inputs and initializers from the original graph.
+ /// can discard some inputs from the original graph.
///
+ /// \note ONNX initializers are not treated as inputs of the model.
std::vector<std::string> model_inputs() const;
/// \brief Returns a list of all outputs of the in-memory model.
/// The returned value might depend on the previous operations executed on an
/// instance of the model editor.
std::vector<std::string> model_outputs() const;
/// \brief Returns true if the input edge is an input of the model, false otherwise.
bool is_input(const InputEdge& edge) const;
/// \brief Returns true if the output edge is an output of the model, false otherwise.
bool is_output(const OutputEdge& edge) const;
/// \brief Returns the path to the original model file
const std::string& model_path() const;
@@ -161,6 +180,12 @@
///
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
/// \brief Returns true if the provided tensor name is correct (exists in the graph).
///
/// \param name The name of a tensor in the graph.
///
bool is_correct_tensor_name(const std::string& name) const;
/// \brief Returns an nGraph Function based on the edited model
///        decoded to framework nodes
///
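The new editor-level queries compose naturally. A usage sketch, with an illustrative path and tensor names borrowed from the test models below:

    onnx_editor::ONNXModelEditor editor{"model.onnx"};

    const auto inputs = editor.model_inputs();   // initializers are excluded
    const auto outputs = editor.model_outputs();

    if (editor.is_correct_tensor_name("add_out"))
    {
        const auto shape = editor.get_tensor_shape("add_out");
    }

    const auto edge = editor.find_input_edge(onnx_editor::EditorOutput{"add_out"}, 0);
    const bool fed_by_model_input = editor.is_input(edge);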

View File

@@ -12,6 +12,7 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx_common/utils.hpp"
#include "onnx_import/core/node.hpp"
#include "utils/common.hpp"
@@ -35,7 +36,7 @@ namespace ngraph
if (onnx_tensor.has_shape())
{
- m_partial_shape = to_ng_shape(onnx_tensor.shape());
+ m_partial_shape = onnx_common::to_ng_shape(onnx_tensor.shape());
}
else
{
@@ -87,28 +88,6 @@ namespace ngraph
return tensor.get_ng_constant();
}
- PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape) const
- {
-     if (onnx_shape.dim_size() == 0)
-     {
-         return Shape{}; // empty list of dimensions denotes a scalar
-     }
-     std::vector<Dimension> dims;
-     for (const auto& onnx_dim : onnx_shape.dim())
-     {
-         if (onnx_dim.has_dim_value())
-         {
-             dims.emplace_back(onnx_dim.dim_value());
-         }
-         else // has_dim_param() == true or it is empty dim
-         {
-             dims.push_back(Dimension::dynamic());
-         }
-     }
-     return PartialShape{dims};
- }
private:
const ONNX_NAMESPACE::ValueInfoProto* m_value_info_proto;
PartialShape m_partial_shape;

View File

@@ -256,3 +256,38 @@ bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode&
{
return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
}
bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) const
{
if (m_node_output_name_to_index.find(name) != std::end(m_node_output_name_to_index))
{
return true;
}
if (m_output_consumers_index.find(name) != std::end(m_output_consumers_index))
{
return true;
}
return false;
}
std::string onnx_editor::EdgeMapper::get_input_port_name(const InputEdge& edge) const
{
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_inputs.size()) &&
edge.m_port_idx >= 0 &&
edge.m_port_idx < static_cast<int>(m_node_inputs[edge.m_node_idx].size()))
{
return m_node_inputs[edge.m_node_idx][edge.m_port_idx];
}
return "";
}
std::string onnx_editor::EdgeMapper::get_output_port_name(const OutputEdge& edge) const
{
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_outputs.size()) &&
edge.m_port_idx >= 0 &&
edge.m_port_idx < static_cast<int>(m_node_outputs[edge.m_node_idx].size()))
{
return m_node_outputs[edge.m_node_idx][edge.m_port_idx];
}
return "";
}
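Both port-name getters validate the node and port indices before any lookup, so a malformed or stale edge degrades to an empty string instead of out-of-range access; the editor's is_input/is_output rely on exactly that. A sketch, assuming a mapper already built for the test model used elsewhere in this change:

    const auto name = mapper.get_input_port_name(InputEdge{0, 0});  // e.g. "in1"
    const auto none = mapper.get_input_port_name(InputEdge{99, 0}); // "" - out of range
    // is_input()/is_output() treat the empty string as "not a model input/output"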

View File

@@ -35,6 +35,20 @@ namespace
return nullptr;
}
ValueInfoProto* find_graph_output(GraphProto& graph, const std::string& name)
{
for (int i = 0; i < graph.output_size(); ++i)
{
auto* output_desc = graph.mutable_output(i);
if (output_desc->has_name() && output_desc->name() == name)
{
return output_desc;
}
}
return nullptr;
}
TensorProto* find_graph_initializer(GraphProto& graph, const std::string& name)
{
for (int i = 0; i < graph.initializer_size(); ++i)
@@ -182,6 +196,31 @@
tensor_type->set_elem_type(initializer.data_type());
}
}
class InferShapesAutoRelease
{
public:
InferShapesAutoRelease(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto)
: m_model_proto{model_proto}
, m_infer_shapes_was_run{false}
{
}
void infer_shapes()
{
ONNX_NAMESPACE::shape_inference::InferShapes(*m_model_proto);
m_infer_shapes_was_run = true;
}
~InferShapesAutoRelease()
{
if (m_infer_shapes_was_run)
{
m_model_proto->mutable_graph()->clear_value_info();
}
}
private:
std::shared_ptr<ONNX_NAMESPACE::ModelProto> m_model_proto;
bool m_infer_shapes_was_run;
};
} // namespace
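InferShapesAutoRelease replaces the manual infer_shapes()/remove_shape_inference_info() pair on Impl (deleted below): because the cleanup now lives in a destructor, the inferred value_info entries are removed even when the caller returns early or throws, which the old cut_graph_fragment did not guarantee. A minimal usage sketch, assuming a loaded model_proto:

    {
        InferShapesAutoRelease onnx_shapes(model_proto);
        onnx_shapes.infer_shapes(); // may throw; the flag is set only on success
        // ... read the inferred shapes from graph().value_info() ...
    } // scope exit: value_info is cleared only if inference actually ran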
/// \brief A helper class used to hold the ModelProto object as its field
@@ -198,9 +237,6 @@ struct onnx_editor::ONNXModelEditor::Impl
onnx_common::parse_from_file(model_path))}
{
}
- void infer_shapes() { ONNX_NAMESPACE::shape_inference::InferShapes(*m_model_proto.get()); }
- void remove_shape_inference_info() { m_model_proto->mutable_graph()->clear_value_info(); }
};
onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path)
@@ -274,6 +310,58 @@ void onnx_editor::ONNXModelEditor::set_input_shapes(
}
}
PartialShape onnx_editor::ONNXModelEditor::get_tensor_shape(const std::string& tensor_name) const
{
const ValueInfoProto* value_info = nullptr;
auto* onnx_graph = m_pimpl->m_model_proto->mutable_graph();
InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto);
if (const auto* input = find_graph_input(*onnx_graph, tensor_name))
{
value_info = input;
}
else if (const auto* output = find_graph_output(*onnx_graph, tensor_name))
{
value_info = output;
}
else
{
try
{
onnx_shapes.infer_shapes();
}
catch (const std::exception& e)
{
NGRAPH_WARN << "Cannot replace existing shapes during get_tensor_shape";
return PartialShape::dynamic();
}
auto node_it = std::find_if(std::begin(onnx_graph->value_info()),
std::end(onnx_graph->value_info()),
[&tensor_name](const ValueInfoProto& value_info) -> bool {
return value_info.name() == tensor_name;
});
if (node_it != std::end(onnx_graph->value_info()))
{
value_info = &(*node_it);
}
}
if (value_info != nullptr)
{
const auto& onnx_tensor_type = value_info->type().tensor_type();
if (onnx_tensor_type.has_shape())
{
return onnx_common::to_ng_shape(onnx_tensor_type.shape());
}
else
{
return PartialShape::dynamic();
}
}
else
{
throw ngraph_error("The tensor: " + tensor_name + " was not found in the graph");
}
}
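get_tensor_shape resolves a name in three steps: graph inputs first, then graph outputs, and only then on-demand shape inference over value_info; an unknown name throws, while a failed inference run downgrades internal tensors to a fully dynamic shape. In short (path and names assumed from the test models):

    onnx_editor::ONNXModelEditor editor{"model.onnx"};
    const auto a = editor.get_tensor_shape("in1");     // read from graph inputs
    const auto b = editor.get_tensor_shape("out1");    // read from graph outputs
    const auto c = editor.get_tensor_shape("add_out"); // filled in by shape inference
    // get_tensor_shape("no_such_tensor") would throw; if shape inference itself
    // fails, internal tensors come back as PartialShape::dynamic()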
void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdge>& inputs,
const std::vector<OutputEdge>& outputs)
{
@@ -282,35 +370,78 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdg
return;
}
- m_pimpl->infer_shapes();
+ InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto);
+ onnx_shapes.infer_shapes();
SubgraphExtractor editor{*(m_pimpl->m_model_proto->mutable_graph())};
editor.add_new_inputs(inputs);
editor.add_new_outputs(outputs);
editor.extract_subgraph(outputs);
- m_pimpl->remove_shape_inference_info();
m_pimpl->m_is_mapper_updated = false;
}
std::vector<std::string> onnx_editor::ONNXModelEditor::model_inputs() const
{
const auto& graph = m_pimpl->m_model_proto->graph();
+ std::vector<std::string> inputs;
+ inputs.reserve(graph.input_size() - graph.initializer_size());
+ for (const auto& in : graph.input())
+ {
+     if (std::find_if(graph.initializer().begin(),
+                      graph.initializer().end(),
+                      [&in](const TensorProto& initializer) {
+                          return initializer.name() == in.name();
+                      }) == graph.initializer().end())
+     {
+         inputs.push_back(in.name());
+     }
+ }
+ return inputs;
+ }
- std::vector<std::string> inputs_and_initializers;
- inputs_and_initializers.reserve(graph.input_size() + graph.initializer_size());
+ std::vector<std::string> onnx_editor::ONNXModelEditor::model_outputs() const
+ {
+ const auto& graph = m_pimpl->m_model_proto->graph();
+ std::vector<std::string> outputs;
+ outputs.reserve(graph.output_size());
- std::transform(graph.input().begin(),
-                graph.input().end(),
-                std::back_inserter(inputs_and_initializers),
+ std::transform(graph.output().begin(),
+                graph.output().end(),
+                std::back_inserter(outputs),
extract_name<ONNX_NAMESPACE::ValueInfoProto>);
- std::transform(graph.initializer().begin(),
-                graph.initializer().end(),
-                std::back_inserter(inputs_and_initializers),
-                extract_name<ONNX_NAMESPACE::TensorProto>);
+ return outputs;
+ }
- return inputs_and_initializers;
bool onnx_editor::ONNXModelEditor::is_input(const InputEdge& edge) const
{
update_mapper_if_needed();
const auto& port_name = m_pimpl->m_edge_mapper.get_input_port_name(edge);
if (port_name.empty())
{
return false;
}
else
{
const auto& inputs = model_inputs();
return std::count(std::begin(inputs), std::end(inputs), port_name) > 0;
}
}
bool onnx_editor::ONNXModelEditor::is_output(const OutputEdge& edge) const
{
update_mapper_if_needed();
const auto& port_name = m_pimpl->m_edge_mapper.get_output_port_name(edge);
if (port_name.empty())
{
return false;
}
else
{
const auto& outputs = model_outputs();
return std::count(std::begin(outputs), std::end(outputs), port_name) > 0;
}
}
std::string onnx_editor::ONNXModelEditor::model_string() const
@@ -393,6 +524,12 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN
return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
}
bool onnx_editor::ONNXModelEditor::is_correct_tensor_name(const std::string& name) const
{
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.is_correct_tensor_name(name);
}
std::shared_ptr<Function> onnx_editor::ONNXModelEditor::decode()
{
return onnx_import::detail::decode_to_framework_nodes(m_pimpl->m_model_proto, m_model_path);

View File

@@ -0,0 +1,551 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import onnx
import pytest
from onnx.helper import make_graph, make_model, make_tensor_value_info
from ngraph import PartialShape
from ngraph.frontend import FrontEndManager
#      in1     in2     in3
#       |       |       |
#        \     /        |
#      +--------+   +------+
#      |  Add   |   | Relu |
#      +--------+   +------+
#       <add_out>       |
#       /       \       |
#  +--------+  +-----+ out3
#  | Split  |  | Mul |
#  |(split1)|  |     |
#  +--------+  +-----+
#    /    \       |
#  out1  out2    out4
#
def create_test_onnx_models():
models = {}
# Input model
add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"])
split = onnx.helper.make_node("Split", inputs=["add_out"],
outputs=["out1", "out2"], name="split1", axis=0)
relu = onnx.helper.make_node("Relu", inputs=["in3"], outputs=["out3"])
mul = onnx.helper.make_node("Mul", inputs=["add_out", "add_out"], outputs=["out4"])
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for extract_subgraph
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add], "test_graph", input_tensors, output_tensors)
models["extract_subgraph.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for extract_subgraph 2
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for extract_subgraph 3
input_tensors = [
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
]
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
outputs=["out1", "out2"], name="split1", axis=0)
graph = make_graph([expected_split], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_3.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for extract_subgraph 4
input_tensors = [
make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
]
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
outputs=["out1", "out2"])
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
outputs=["out4"])
graph = make_graph([expected_split, expected_mul], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for test_override_all_outputs
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
models["test_override_all_outputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for test_override_all_outputs 2
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add, mul], "test_graph", input_tensors, output_tensors)
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
# Expected for test_override_all_inputs
input_tensors = [
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
]
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
outputs=["out1", "out2"])
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
outputs=["out4"])
graph = make_graph([expected_split, relu, expected_mul], "test_graph", input_tensors, output_tensors)
models["test_override_all_inputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
# test partial shape
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (8, 16)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (8, 16)),
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (4, 6)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (4, 16)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (4, 16)),
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (4, 6)),
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (8, 16)),
]
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer")
return models
fem = FrontEndManager()
test_models_names = []
def setup_module():
models = create_test_onnx_models()
for name, model in models.items():
onnx.save_model(model, name)
test_models_names.append(name)
def teardown_module():
for name in test_models_names:
os.remove(name)
def skip_if_onnx_frontend_is_disabled():
front_ends = fem.get_available_front_ends()
if "onnx" not in front_ends:
pytest.skip()
# Function to compare nG Functions (op names, types and shapes).
# Note that the function uses get_ordered_ops, so the topological order of ops should also be preserved.
def compare_functions(current, expected): # noqa: C901 the function is too complex
result = True
msg = ""
if current.get_friendly_name() != expected.get_friendly_name():
result = False
msg += "Friendly name of nG Functions not equal. "
msg += f"Current: {current.get_friendly_name()}, expected: {expected.get_friendly_name()}. "
current_ops = current.get_ordered_ops()
expected_ops = expected.get_ordered_ops()
if len(current_ops) != len(expected_ops):
result = False
msg += "Not equal number of ops. "
msg += f"Current: {len(current_ops)}, expected: {len(expected_ops)}. "
for i in range(len(current_ops)):
if (current_ops[i].get_friendly_name() != expected_ops[i].get_friendly_name()
and current_ops[i].get_type_name() != "Constant"):  # Constants may have different names
result = False
msg += "Not equal op name. "
msg += f"Current: {current_ops[i].get_friendly_name()}, "
msg += f"expected: {expected_ops[i].get_friendly_name()}. "
if current_ops[i].get_output_size() != expected_ops[i].get_output_size():
result = False
msg += f"Not equal output size of {current_ops[i].get_friendly_name()}. "
for j in range(current_ops[i].get_output_size()):
if current_ops[i].get_output_partial_shape(j) != expected_ops[i].get_output_partial_shape(j):
result = False
msg += f"Not equal op partial shapes of {current_ops[i].get_friendly_name()}. "
msg += f"Current: {current_ops[i].get_partial_shape({j})}, "
msg += f"expected: {expected_ops[i].get_partial_shape({j})}. "
if current_ops[i].get_output_element_type(j) != expected_ops[i].get_output_element_type(j):
result = False
msg += f"Not equal output element type of {current_ops[i].get_friendly_name()}. "
msg += f"Current: {current_ops[i].get_output_element_type(j)}, "
msg += f"expected: {expected_ops[i].get_output_element_type(j)}. "
if not result:
print(msg)
return result
def test_extract_subgraph():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out").get_input_port(inputPortIndex=0) # in1
place2 = model.get_place_by_tensor_name(tensorName="add_out").get_input_port(inputPortIndex=1) # in2
place3 = model.get_place_by_tensor_name(tensorName="add_out")
model.extract_subgraph(inputs=[place1, place2], outputs=[place3])
result_func = fe.convert(model)
expected_model = fe.load("extract_subgraph.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_extract_subgraph_2():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
place2 = model.get_place_by_tensor_name(tensorName="out3")
model.extract_subgraph(inputs=[], outputs=[place1, place2])
result_func = fe.convert(model)
expected_model = fe.load("extract_subgraph_2.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_extract_subgraph_3():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place2 = model.get_place_by_tensor_name(tensorName="out1")
place3 = model.get_place_by_tensor_name(tensorName="out2")
model.extract_subgraph(inputs=[place1], outputs=[place2, place3])
result_func = fe.convert(model)
expected_model = fe.load("extract_subgraph_3.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_extract_subgraph_4():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place4 = model.get_place_by_tensor_name(tensorName="out1")
place5 = model.get_place_by_tensor_name(tensorName="out2")
place6 = model.get_place_by_tensor_name(tensorName="out4")
model.extract_subgraph(inputs=[place1, place2, place3], outputs=[place4, place5, place6])
result_func = fe.convert(model)
expected_model = fe.load("extract_subgraph_4.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_override_all_outputs():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
place2 = model.get_place_by_tensor_name(tensorName="out3")
model.override_all_outputs(outputs=[place1, place2])
result_func = fe.convert(model)
expected_model = fe.load("test_override_all_outputs.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_override_all_outputs_2():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="out4")
model.override_all_outputs(outputs=[place1])
result_func = fe.convert(model)
expected_model = fe.load("test_override_all_outputs_2.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_override_all_inputs():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_operation_name_and_input_port(
operationName="split1", inputPortIndex=0)
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
place4 = model.get_place_by_tensor_name(tensorName="in3")
model.override_all_inputs(inputs=[place1, place2, place3, place4])
result_func = fe.convert(model)
expected_model = fe.load("test_override_all_inputs.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_override_all_inputs_exceptions():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
place2 = model.get_place_by_tensor_name(tensorName="in2")
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place4 = model.get_place_by_tensor_name(tensorName="in3")
with pytest.raises(Exception) as e:
model.override_all_inputs(inputs=[place1, place2])
assert "Unexpected number of inputs after override_all_inputs" in str(e)
with pytest.raises(Exception) as e:
model.override_all_inputs(inputs=[place3, place4])
assert "Unexpected number of inputs after override_all_inputs" in str(e)
def test_is_input_output():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in2")
assert place1.is_input()
assert not place1.is_output()
place2 = model.get_place_by_tensor_name(tensorName="out2")
assert not place2.is_input()
assert place2.is_output()
place3 = model.get_place_by_tensor_name(tensorName="add_out")
assert not place3.is_input()
assert not place3.is_output()
place4 = model.get_place_by_operation_name_and_input_port(
operationName="split1", inputPortIndex=0)
assert not place4.is_input()
assert not place4.is_output()
def test_set_partial_shape():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
model.set_partial_shape(place1, PartialShape([8, 16]))
place2 = model.get_place_by_tensor_name(tensorName="in2")
model.set_partial_shape(place2, PartialShape([8, 16]))
place3 = model.get_place_by_tensor_name(tensorName="in3")
model.set_partial_shape(place3, PartialShape([4, 6]))
result_func = fe.convert(model)
expected_model = fe.load("test_partial_shape.onnx")
expected_func = fe.convert(expected_model)
res = compare_functions(result_func, expected_func)
assert res
def test_get_partial_shape():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
assert model.get_partial_shape(place1) == PartialShape([2, 2])
place2 = model.get_place_by_tensor_name(tensorName="out1")
assert model.get_partial_shape(place2) == PartialShape([1, 2])
place3 = model.get_place_by_tensor_name(tensorName="add_out")
assert model.get_partial_shape(place3) == PartialShape([2, 2])
place4 = model.get_place_by_tensor_name(tensorName="in3")
model.set_partial_shape(place4, PartialShape([4, 6]))
assert model.get_partial_shape(place4) == PartialShape([4, 6])
assert model.get_partial_shape(place2) == PartialShape([1, 2])
def test_get_inputs():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
in_names = [place.get_names()[0] for place in model.get_inputs()]
assert in_names == ["in1", "in2", "in3"]
def test_get_outputs():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
out_names = [place.get_names()[0] for place in model.get_outputs()]
assert out_names == ["out1", "out2", "out3", "out4"]
def test_is_equal():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
assert place1.is_equal(place1)
place2 = model.get_place_by_tensor_name(tensorName="out2")
assert place2.is_equal(place2)
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
place4 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
assert place3.is_equal(place4)
place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place6 = model.get_place_by_tensor_name(tensorName="out1").get_input_port(inputPortIndex=0)
assert place5.is_equal(place6)
place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port()
assert place7.is_equal(place7)
place8 = model.get_place_by_tensor_name(tensorName="add_out")
assert place8.is_equal(place8)
assert not place1.is_equal(place2)
assert not place6.is_equal(place7)
assert not place8.is_equal(place2)
def test_get_place_by_tensor_name():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework="onnx")
assert fe
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="out2")
assert place1
place2 = model.get_place_by_tensor_name(tensorName="add_out")
assert place2
place3 = model.get_place_by_tensor_name(tensorName="in1")
assert place3
with pytest.raises(Exception) as e:
model.get_place_by_tensor_name(tensorName="0:add_out")
assert "The tensor with name: 0:add_out does not exist in the graph" in str(e)

View File

@@ -1462,3 +1462,110 @@ NGRAPH_TEST(onnx_editor, cut_operator_with_no_schema)
EXPECT_TRUE(result.is_ok) << result.error_message;
}
NGRAPH_TEST(onnx_editor, is_model_input)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
EXPECT_TRUE(editor.is_input(InputEdge{0, 0}));
const auto edge1 = editor.find_input_edge(EditorOutput{"add1"}, 1);
EXPECT_TRUE(editor.is_input(edge1));
EXPECT_FALSE(editor.is_input(InputEdge{1, 2}));
EXPECT_FALSE(editor.is_input(InputEdge{3, 0}));
EXPECT_FALSE(editor.is_input(InputEdge{11, 0}));
const auto edge2 = editor.find_input_edge(EditorOutput{"conv1"}, 2);
EXPECT_FALSE(editor.is_input(edge2));
EXPECT_FALSE(editor.is_input(InputEdge{2, 1})); // initializer is not treated as input
const auto edge3 = editor.find_input_edge(EditorOutput{"conv1"}, EditorInput{"in4"});
EXPECT_FALSE(editor.is_input(edge3));
}
NGRAPH_TEST(onnx_editor, is_model_output)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
EXPECT_TRUE(editor.is_output(OutputEdge{4, 0}));
EXPECT_TRUE(editor.is_output(OutputEdge{5, 1}));
const auto edge1 = editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{"split2"});
EXPECT_TRUE(editor.is_output(edge1));
EXPECT_FALSE(editor.is_output(OutputEdge{4, 1}));
EXPECT_FALSE(editor.is_output(OutputEdge{0, 0}));
EXPECT_FALSE(editor.is_output(OutputEdge{11, 0}));
const auto edge2 = editor.find_output_edge("add2");
EXPECT_FALSE(editor.is_output(edge2));
}
NGRAPH_TEST(onnx_editor, model_inputs)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const auto inputs = editor.model_inputs();
EXPECT_TRUE(inputs == (std::vector<std::string>{"in1", "in2", "in3"})); // in4 is initializer
}
NGRAPH_TEST(onnx_editor, model_output)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const auto outputs = editor.model_outputs();
EXPECT_TRUE(outputs == (std::vector<std::string>{"mul1", "split2", "mul2"}));
}
NGRAPH_TEST(onnx_editor, get_tensor_shape)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
EXPECT_EQ(editor.get_tensor_shape("mul2"), (PartialShape{1, 1, 2, 2}));
EXPECT_EQ(editor.get_tensor_shape("in1"), (PartialShape{2, 2}));
EXPECT_EQ(editor.get_tensor_shape("in2"), (PartialShape{}));
EXPECT_EQ(editor.get_tensor_shape("in3"), (PartialShape{1, 1, 2, 2}));
EXPECT_EQ(editor.get_tensor_shape("relu1"), (PartialShape{2, 2}));
EXPECT_EQ(editor.get_tensor_shape("add1"), (PartialShape{2, 2}));
try
{
editor.get_tensor_shape("not_existed");
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(
msg.find("The tensor: not_existed was not found in the graph") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, get_tensor_shape_after_modification)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
EXPECT_EQ(editor.get_tensor_shape("in3"), (PartialShape{1, 1, 2, 2}));
EXPECT_EQ(editor.get_tensor_shape("conv1"), (PartialShape{1, 1, 2, 2}));
EXPECT_EQ(editor.get_tensor_shape("mul2"), (PartialShape{1, 1, 2, 2}));
editor.set_input_shapes({{"in3", (PartialShape{1, 1, 4, 4})}});
EXPECT_EQ(editor.get_tensor_shape("conv1"), (PartialShape{1, 1, 4, 4}));
EXPECT_EQ(editor.get_tensor_shape("in3"), (PartialShape{1, 1, 4, 4}));
}
NGRAPH_TEST(onnx_editor, is_correct_tensor_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
EXPECT_TRUE(editor.is_correct_tensor_name("in1"));
EXPECT_TRUE(editor.is_correct_tensor_name("relu1"));
EXPECT_TRUE(editor.is_correct_tensor_name("split2"));
EXPECT_TRUE(editor.is_correct_tensor_name("mul2"));
EXPECT_TRUE(editor.is_correct_tensor_name("in4"));
EXPECT_FALSE(editor.is_correct_tensor_name("relu1_name"));
EXPECT_FALSE(editor.is_correct_tensor_name("not_existed"));
EXPECT_FALSE(editor.is_correct_tensor_name(""));
}