Integration ONNX Editor with FE API (#6773)
This commit is contained in:
parent
9acedbdacf
commit
dc5f44e929
@ -277,7 +277,7 @@ void InputModel::set_partial_shape(Place::Ptr place, const ngraph::PartialShape&
|
||||
|
||||
ngraph::PartialShape InputModel::get_partial_shape(Place::Ptr place) const
|
||||
{
|
||||
FRONT_END_NOT_IMPLEMENTED(set_partial_shape);
|
||||
FRONT_END_NOT_IMPLEMENTED(get_partial_shape);
|
||||
}
|
||||
|
||||
void InputModel::set_element_type(Place::Ptr place, const ngraph::element::Type&)
|
||||
|
@ -2,55 +2,143 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "input_model.hpp"
|
||||
#include <frontend_manager/frontend_exceptions.hpp>
|
||||
#include <input_model.hpp>
|
||||
#include <place.hpp>
|
||||
#include "place.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::frontend;
|
||||
|
||||
InputModelONNX::InputModelONNX(const std::string& path)
|
||||
: m_editor(path)
|
||||
: m_editor{std::make_shared<onnx_editor::ONNXModelEditor>(path)}
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> InputModelONNX::get_inputs() const
|
||||
{
|
||||
auto inputs = m_editor.model_inputs();
|
||||
std::vector<Place::Ptr> ret;
|
||||
ret.reserve(inputs.size());
|
||||
const auto& inputs = m_editor->model_inputs();
|
||||
std::vector<Place::Ptr> in_places;
|
||||
in_places.reserve(inputs.size());
|
||||
for (const auto& input : inputs)
|
||||
{
|
||||
ret.push_back(std::make_shared<PlaceTensorONNX>(input, m_editor));
|
||||
in_places.push_back(std::make_shared<PlaceTensorONNX>(input, m_editor));
|
||||
}
|
||||
return ret;
|
||||
return in_places;
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> InputModelONNX::get_outputs() const
|
||||
{
|
||||
const auto& outputs = m_editor->model_outputs();
|
||||
std::vector<Place::Ptr> out_places;
|
||||
out_places.reserve(outputs.size());
|
||||
for (const auto& output : outputs)
|
||||
{
|
||||
out_places.push_back(std::make_shared<PlaceTensorONNX>(output, m_editor));
|
||||
}
|
||||
return out_places;
|
||||
}
|
||||
|
||||
Place::Ptr InputModelONNX::get_place_by_tensor_name(const std::string& tensor_name) const
|
||||
{
|
||||
NGRAPH_CHECK(m_editor->is_correct_tensor_name(tensor_name),
|
||||
"The tensor with name: " + tensor_name + " does not exist in the graph");
|
||||
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
|
||||
}
|
||||
|
||||
Place::Ptr
|
||||
InputModelONNX::get_place_by_operation_name_and_input_port(const std::string& operation_name,
|
||||
int input_port_index)
|
||||
{
|
||||
const auto edge =
|
||||
m_editor->find_input_edge(onnx_editor::EditorNode(operation_name), input_port_index);
|
||||
return std::make_shared<PlaceInputEdgeONNX>(edge, m_editor);
|
||||
}
|
||||
|
||||
void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape)
|
||||
{
|
||||
std::map<std::string, ngraph::PartialShape> m;
|
||||
m[place->get_names()[0]] = shape;
|
||||
m_editor.set_input_shapes(m);
|
||||
m_editor->set_input_shapes(m);
|
||||
}
|
||||
|
||||
ngraph::PartialShape InputModelONNX::get_partial_shape(Place::Ptr place) const
|
||||
{
|
||||
return m_editor->get_tensor_shape(place->get_names().at(0));
|
||||
}
|
||||
|
||||
void InputModelONNX::set_element_type(Place::Ptr place, const ngraph::element::Type& type)
|
||||
{
|
||||
std::map<std::string, ngraph::element::Type_t> m;
|
||||
m[place->get_names()[0]] = type;
|
||||
m_editor.set_input_types(m);
|
||||
m_editor->set_input_types(m);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> InputModelONNX::decode()
|
||||
{
|
||||
return m_editor.decode();
|
||||
return m_editor->decode();
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> InputModelONNX::convert()
|
||||
{
|
||||
return m_editor.get_function();
|
||||
return m_editor->get_function();
|
||||
}
|
||||
|
||||
// Editor features
|
||||
void InputModelONNX::override_all_outputs(const std::vector<Place::Ptr>& outputs)
|
||||
{
|
||||
extract_subgraph({}, outputs);
|
||||
NGRAPH_CHECK(m_editor->model_outputs().size() == outputs.size(),
|
||||
"Unexpected number of outputs after override_all_outputs");
|
||||
NGRAPH_CHECK(std::all_of(std::begin(outputs),
|
||||
std::end(outputs),
|
||||
[](const Place::Ptr& place) { return place->is_output(); }),
|
||||
"Not all provided arguments of override_all_outputs are new outputs of the model");
|
||||
}
|
||||
|
||||
void InputModelONNX::override_all_inputs(const std::vector<Place::Ptr>& inputs)
|
||||
{
|
||||
const auto outputs_before_extraction = m_editor->model_outputs();
|
||||
extract_subgraph({inputs}, {});
|
||||
NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction),
|
||||
std::end(outputs_before_extraction),
|
||||
std::begin(m_editor->model_outputs())),
|
||||
"All outputs should be preserved after override_all_inputs. Provided inputs does "
|
||||
"not satisfy all outputs");
|
||||
NGRAPH_CHECK(m_editor->model_inputs().size() == inputs.size(),
|
||||
"Unexpected number of inputs after override_all_inputs");
|
||||
}
|
||||
|
||||
void InputModelONNX::extract_subgraph(const std::vector<Place::Ptr>& inputs,
|
||||
const std::vector<Place::Ptr>& outputs)
|
||||
{
|
||||
std::vector<onnx_editor::InputEdge> onnx_inputs;
|
||||
onnx_inputs.reserve(inputs.size());
|
||||
for (const auto& input : inputs)
|
||||
{
|
||||
if (const auto input_port = std::dynamic_pointer_cast<PlaceInputEdgeONNX>(input))
|
||||
{
|
||||
onnx_inputs.push_back(input_port->get_input_edge());
|
||||
}
|
||||
else if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(input))
|
||||
{
|
||||
auto name = tensor->get_names()[0];
|
||||
const auto consumers = m_editor->find_output_consumers(name);
|
||||
std::transform(std::begin(consumers),
|
||||
std::end(consumers),
|
||||
std::back_inserter(onnx_inputs),
|
||||
[](const onnx_editor::InputEdge& edge) { return edge; });
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<onnx_editor::OutputEdge> onnx_outputs;
|
||||
onnx_outputs.reserve(outputs.size());
|
||||
for (const auto& output : outputs)
|
||||
{
|
||||
const auto output_port = output->get_producing_port();
|
||||
const auto onnx_output_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output_port);
|
||||
NGRAPH_CHECK(onnx_output_edge,
|
||||
"Non-onnx output place was passed as extraction subgraph argument");
|
||||
onnx_outputs.push_back(onnx_output_edge->get_output_edge());
|
||||
}
|
||||
m_editor->cut_graph_fragment(onnx_inputs, onnx_outputs);
|
||||
}
|
||||
|
@ -17,15 +17,25 @@ namespace ngraph
|
||||
InputModelONNX(const std::string& path);
|
||||
|
||||
std::vector<Place::Ptr> get_inputs() const override;
|
||||
std::vector<Place::Ptr> get_outputs() const override;
|
||||
Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override;
|
||||
Place::Ptr get_place_by_operation_name_and_input_port(const std::string& operation_name,
|
||||
int input_port_index) override;
|
||||
void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override;
|
||||
ngraph::PartialShape get_partial_shape(Place::Ptr place) const override;
|
||||
void set_element_type(Place::Ptr place, const ngraph::element::Type& type) override;
|
||||
|
||||
std::shared_ptr<Function> decode();
|
||||
std::shared_ptr<Function> convert();
|
||||
|
||||
// Editor features
|
||||
void override_all_outputs(const std::vector<Place::Ptr>& outputs) override;
|
||||
void override_all_inputs(const std::vector<Place::Ptr>& inputs) override;
|
||||
void extract_subgraph(const std::vector<Place::Ptr>& inputs,
|
||||
const std::vector<Place::Ptr>& outputs) override;
|
||||
|
||||
private:
|
||||
onnx_editor::ONNXModelEditor m_editor;
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
};
|
||||
|
||||
} // namespace frontend
|
||||
|
134
ngraph/frontend/onnx/frontend/src/place.cpp
Normal file
134
ngraph/frontend/onnx/frontend/src/place.cpp
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "place.hpp"
|
||||
#include <frontend_manager/frontend_exceptions.hpp>
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::frontend;
|
||||
|
||||
PlaceInputEdgeONNX::PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_edge{edge}
|
||||
, m_editor{editor}
|
||||
{
|
||||
}
|
||||
|
||||
onnx_editor::InputEdge PlaceInputEdgeONNX::get_input_edge() const
|
||||
{
|
||||
return m_edge;
|
||||
}
|
||||
|
||||
bool PlaceInputEdgeONNX::is_input() const
|
||||
{
|
||||
return m_editor->is_input(m_edge);
|
||||
}
|
||||
|
||||
bool PlaceInputEdgeONNX::is_output() const
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlaceInputEdgeONNX::is_equal(Place::Ptr another) const
|
||||
{
|
||||
if (const auto in_edge = std::dynamic_pointer_cast<PlaceInputEdgeONNX>(another))
|
||||
{
|
||||
const auto& editor_edge = in_edge->get_input_edge();
|
||||
return (editor_edge.m_node_idx == m_edge.m_node_idx) &&
|
||||
(editor_edge.m_port_idx == m_edge.m_port_idx);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_edge{edge}
|
||||
, m_editor{editor}
|
||||
{
|
||||
}
|
||||
|
||||
onnx_editor::OutputEdge PlaceOutputEdgeONNX::get_output_edge() const
|
||||
{
|
||||
return m_edge;
|
||||
}
|
||||
|
||||
bool PlaceOutputEdgeONNX::is_input() const
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlaceOutputEdgeONNX::is_output() const
|
||||
{
|
||||
return m_editor->is_output(m_edge);
|
||||
}
|
||||
|
||||
bool PlaceOutputEdgeONNX::is_equal(Place::Ptr another) const
|
||||
{
|
||||
if (const auto out_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(another))
|
||||
{
|
||||
const auto& editor_edge = out_edge->get_output_edge();
|
||||
return (editor_edge.m_node_idx == m_edge.m_node_idx) &&
|
||||
(editor_edge.m_port_idx == m_edge.m_port_idx);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
PlaceTensorONNX::PlaceTensorONNX(const std::string& name,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
|
||||
: m_name(name)
|
||||
, m_editor(editor)
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<std::string> PlaceTensorONNX::get_names() const
|
||||
{
|
||||
return {m_name};
|
||||
}
|
||||
|
||||
Place::Ptr PlaceTensorONNX::get_producing_port() const
|
||||
{
|
||||
return std::make_shared<PlaceOutputEdgeONNX>(m_editor->find_output_edge(m_name), m_editor);
|
||||
}
|
||||
|
||||
std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_ports() const
|
||||
{
|
||||
std::vector<Place::Ptr> ret;
|
||||
auto edges = m_editor->find_output_consumers(m_name);
|
||||
std::transform(edges.begin(),
|
||||
edges.end(),
|
||||
std::back_inserter(ret),
|
||||
[this](const onnx_editor::InputEdge& edge) {
|
||||
return std::make_shared<PlaceInputEdgeONNX>(edge, this->m_editor);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
Place::Ptr PlaceTensorONNX::get_input_port(int input_port_index) const
|
||||
{
|
||||
return std::make_shared<PlaceInputEdgeONNX>(
|
||||
m_editor->find_input_edge(onnx_editor::EditorOutput(m_name),
|
||||
onnx_editor::EditorInput(input_port_index)),
|
||||
m_editor);
|
||||
}
|
||||
|
||||
bool PlaceTensorONNX::is_input() const
|
||||
{
|
||||
const auto inputs = m_editor->model_inputs();
|
||||
return std::find(std::begin(inputs), std::end(inputs), m_name) != std::end(inputs);
|
||||
}
|
||||
|
||||
bool PlaceTensorONNX::is_output() const
|
||||
{
|
||||
const auto outputs = m_editor->model_outputs();
|
||||
return std::find(std::begin(outputs), std::end(outputs), m_name) != std::end(outputs);
|
||||
}
|
||||
|
||||
bool PlaceTensorONNX::is_equal(Place::Ptr another) const
|
||||
{
|
||||
if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(another))
|
||||
{
|
||||
return m_name == tensor->get_names().at(0);
|
||||
}
|
||||
return false;
|
||||
}
|
@ -4,7 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <frontend_manager/place.hpp>
|
||||
#include <onnx_editor/editor.hpp>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -13,65 +15,63 @@ namespace ngraph
|
||||
class PlaceInputEdgeONNX : public Place
|
||||
{
|
||||
public:
|
||||
PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge)
|
||||
: m_edge(edge)
|
||||
{
|
||||
}
|
||||
PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
|
||||
onnx_editor::InputEdge get_input_edge() const;
|
||||
|
||||
bool is_input() const override;
|
||||
|
||||
bool is_output() const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
private:
|
||||
onnx_editor::InputEdge m_edge;
|
||||
const std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
};
|
||||
|
||||
class PlaceOutputEdgeONNX : public Place
|
||||
{
|
||||
public:
|
||||
PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge)
|
||||
: m_edge(edge)
|
||||
{
|
||||
}
|
||||
PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
|
||||
onnx_editor::OutputEdge get_output_edge() const;
|
||||
|
||||
bool is_input() const override;
|
||||
|
||||
bool is_output() const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
private:
|
||||
onnx_editor::OutputEdge m_edge;
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
};
|
||||
|
||||
class PlaceTensorONNX : public Place
|
||||
{
|
||||
public:
|
||||
PlaceTensorONNX(const std::string& name, const onnx_editor::ONNXModelEditor& editor)
|
||||
: m_name(name)
|
||||
, m_editor(editor)
|
||||
{
|
||||
}
|
||||
PlaceTensorONNX(const std::string& name, std::shared_ptr<onnx_editor::ONNXModelEditor> editor);
|
||||
|
||||
std::vector<std::string> get_names() const override { return {m_name}; }
|
||||
std::vector<std::string> get_names() const override;
|
||||
|
||||
Place::Ptr get_producing_port() const override
|
||||
{
|
||||
return std::make_shared<PlaceOutputEdgeONNX>(m_editor.find_output_edge(m_name));
|
||||
}
|
||||
Place::Ptr get_producing_port() const override;
|
||||
|
||||
std::vector<Place::Ptr> get_consuming_ports() const override
|
||||
{
|
||||
std::vector<Place::Ptr> ret;
|
||||
auto edges = m_editor.find_output_consumers(m_name);
|
||||
std::transform(edges.begin(),
|
||||
edges.end(),
|
||||
std::back_inserter(ret),
|
||||
[](const onnx_editor::InputEdge& edge) {
|
||||
return std::make_shared<PlaceInputEdgeONNX>(edge);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
std::vector<Place::Ptr> get_consuming_ports() const override;
|
||||
|
||||
Ptr get_input_port(int input_port_index) const override
|
||||
{
|
||||
return std::make_shared<PlaceInputEdgeONNX>(m_editor.find_input_edge(
|
||||
onnx_editor::EditorNode(m_name), onnx_editor::EditorInput(input_port_index)));
|
||||
}
|
||||
Ptr get_input_port(int input_port_index) const override;
|
||||
|
||||
bool is_input() const override;
|
||||
|
||||
bool is_output() const override;
|
||||
|
||||
bool is_equal(Place::Ptr another) const override;
|
||||
|
||||
private:
|
||||
std::string m_name;
|
||||
const onnx_editor::ONNXModelEditor& m_editor;
|
||||
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
|
||||
};
|
||||
} // namespace frontend
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "ngraph/partial_shape.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
|
||||
namespace ONNX_NAMESPACE
|
||||
@ -38,5 +39,11 @@ namespace ngraph
|
||||
///
|
||||
bool is_supported_ng_type(const element::Type_t& ng_type);
|
||||
|
||||
/// \brief Retuns nG PartialShape based on onnx_shape.
|
||||
///
|
||||
/// \param onnx_shape A shape of tensor represented in ONNX way.
|
||||
///
|
||||
PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape);
|
||||
|
||||
} // namespace onnx_common
|
||||
} // namespace ngraph
|
||||
|
@ -88,5 +88,27 @@ namespace ngraph
|
||||
return NG_2_ONNX_TYPES.count(ng_type) > 0;
|
||||
}
|
||||
|
||||
PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
|
||||
{
|
||||
if (onnx_shape.dim_size() == 0)
|
||||
{
|
||||
return Shape{}; // empty list of dimensions denotes a scalar
|
||||
}
|
||||
|
||||
std::vector<Dimension> dims;
|
||||
for (const auto& onnx_dim : onnx_shape.dim())
|
||||
{
|
||||
if (onnx_dim.has_dim_value())
|
||||
{
|
||||
dims.emplace_back(onnx_dim.dim_value());
|
||||
}
|
||||
else // has_dim_param() == true or it is empty dim
|
||||
{
|
||||
dims.push_back(Dimension::dynamic());
|
||||
}
|
||||
}
|
||||
return PartialShape{dims};
|
||||
}
|
||||
|
||||
} // namespace onnx_common
|
||||
} // namespace ngraph
|
||||
|
@ -99,6 +99,24 @@ namespace ngraph
|
||||
///
|
||||
ONNX_IMPORTER_API bool is_correct_and_unambiguous_node(const EditorNode& node) const;
|
||||
|
||||
/// \brief Returns true if a provided tensor name is correct (exists in a graph).
|
||||
///
|
||||
/// \param name The name of tensor in a graph.
|
||||
///
|
||||
bool is_correct_tensor_name(const std::string& name) const;
|
||||
|
||||
/// \brief Get name of input port indicated by the input edge.
|
||||
///
|
||||
/// \note Empty string is returned if the port name is not found.
|
||||
///
|
||||
std::string get_input_port_name(const InputEdge& edge) const;
|
||||
|
||||
/// \brief Get name of output port indicated by the input edge.
|
||||
///
|
||||
/// \note Empty string is returned if the port name is not found.
|
||||
///
|
||||
std::string get_output_port_name(const OutputEdge& edge) const;
|
||||
|
||||
private:
|
||||
std::vector<int> find_node_indexes(const std::string& node_name,
|
||||
const std::string& output_name) const;
|
||||
|
@ -53,6 +53,12 @@ namespace ngraph
|
||||
/// the inputs specified in its parameter.
|
||||
void set_input_shapes(const std::map<std::string, ngraph::PartialShape>& input_shapes);
|
||||
|
||||
/// \brief Get shape of ONNX tensor indicated by the tensor_name.
|
||||
///
|
||||
/// \param tensor_name The name of ONNX tensor.
|
||||
///
|
||||
PartialShape get_tensor_shape(const std::string& tensor_name) const;
|
||||
|
||||
/// \brief Extracts a subgraph constrained by input edges and output edges. In the end
|
||||
/// the underlying ModelProto is modified - obsolete inputs, initializers, nodes
|
||||
/// and outputs are removed from the in-memory model.
|
||||
@ -86,12 +92,25 @@ namespace ngraph
|
||||
/// \brief Converts an edited ONNX model to an nGraph Function representation.
|
||||
std::shared_ptr<Function> get_function() const;
|
||||
|
||||
/// \brief Returns a list of all inputs of the in-memory model, including initializers.
|
||||
/// \brief Returns a list of all inputs of the in-memory model.
|
||||
/// The returned value might depend on the previous operations executed on an
|
||||
/// instance of the model editor, in particular the subgraph extraction which
|
||||
/// can discard some inputs and initializers from the original graph.
|
||||
/// can discard some inputs from the original graph.
|
||||
///
|
||||
/// \note ONNX initializers is not treated as input of the model.
|
||||
std::vector<std::string> model_inputs() const;
|
||||
|
||||
/// \brief Returns a list of all outputs of the in-memory model.
|
||||
/// The returned value might depend on the previous operations executed on an
|
||||
/// instance of the model editor.
|
||||
std::vector<std::string> model_outputs() const;
|
||||
|
||||
/// \brief Returns true if input edge is input of the model. Otherwise false.
|
||||
bool is_input(const InputEdge& edge) const;
|
||||
|
||||
/// \brief Returns true if output edge is input of the model. Otherwise false.
|
||||
bool is_output(const OutputEdge& edge) const;
|
||||
|
||||
/// \brief Returns the path to the original model file
|
||||
const std::string& model_path() const;
|
||||
|
||||
@ -161,6 +180,12 @@ namespace ngraph
|
||||
///
|
||||
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
|
||||
|
||||
/// \brief Returns true if a provided tensor name is correct (exists in a graph).
|
||||
///
|
||||
/// \param name The name of tensor in a graph.
|
||||
///
|
||||
bool is_correct_tensor_name(const std::string& name) const;
|
||||
|
||||
/// \brief Returns a nGraph function based on edited model
|
||||
/// decoded to framework nodes
|
||||
///
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/partial_shape.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include "onnx_common/utils.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
#include "utils/common.hpp"
|
||||
|
||||
@ -35,7 +36,7 @@ namespace ngraph
|
||||
|
||||
if (onnx_tensor.has_shape())
|
||||
{
|
||||
m_partial_shape = to_ng_shape(onnx_tensor.shape());
|
||||
m_partial_shape = onnx_common::to_ng_shape(onnx_tensor.shape());
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -87,28 +88,6 @@ namespace ngraph
|
||||
return tensor.get_ng_constant();
|
||||
}
|
||||
|
||||
PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape) const
|
||||
{
|
||||
if (onnx_shape.dim_size() == 0)
|
||||
{
|
||||
return Shape{}; // empty list of dimensions denotes a scalar
|
||||
}
|
||||
|
||||
std::vector<Dimension> dims;
|
||||
for (const auto& onnx_dim : onnx_shape.dim())
|
||||
{
|
||||
if (onnx_dim.has_dim_value())
|
||||
{
|
||||
dims.emplace_back(onnx_dim.dim_value());
|
||||
}
|
||||
else // has_dim_param() == true or it is empty dim
|
||||
{
|
||||
dims.push_back(Dimension::dynamic());
|
||||
}
|
||||
}
|
||||
return PartialShape{dims};
|
||||
}
|
||||
|
||||
private:
|
||||
const ONNX_NAMESPACE::ValueInfoProto* m_value_info_proto;
|
||||
PartialShape m_partial_shape;
|
||||
|
@ -256,3 +256,38 @@ bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode&
|
||||
{
|
||||
return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
|
||||
}
|
||||
|
||||
bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) const
|
||||
{
|
||||
if (m_node_output_name_to_index.find(name) != std::end(m_node_output_name_to_index))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if (m_output_consumers_index.find(name) != std::end(m_output_consumers_index))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string onnx_editor::EdgeMapper::get_input_port_name(const InputEdge& edge) const
|
||||
{
|
||||
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_inputs.size()) &&
|
||||
edge.m_port_idx >= 0 &&
|
||||
edge.m_port_idx < static_cast<int>(m_node_inputs[edge.m_node_idx].size()))
|
||||
{
|
||||
return m_node_inputs[edge.m_node_idx][edge.m_port_idx];
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string onnx_editor::EdgeMapper::get_output_port_name(const OutputEdge& edge) const
|
||||
{
|
||||
if (edge.m_node_idx >= 0 && edge.m_node_idx < static_cast<int>(m_node_outputs.size()) &&
|
||||
edge.m_port_idx >= 0 &&
|
||||
edge.m_port_idx < static_cast<int>(m_node_outputs[edge.m_node_idx].size()))
|
||||
{
|
||||
return m_node_outputs[edge.m_node_idx][edge.m_port_idx];
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
@ -35,6 +35,20 @@ namespace
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ValueInfoProto* find_graph_output(GraphProto& graph, const std::string& name)
|
||||
{
|
||||
for (int i = 0; i < graph.output_size(); ++i)
|
||||
{
|
||||
auto* output_desc = graph.mutable_output(i);
|
||||
if (output_desc->has_name() && output_desc->name() == name)
|
||||
{
|
||||
return output_desc;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TensorProto* find_graph_initializer(GraphProto& graph, const std::string& name)
|
||||
{
|
||||
for (int i = 0; i < graph.initializer_size(); ++i)
|
||||
@ -182,6 +196,31 @@ namespace
|
||||
tensor_type->set_elem_type(initializer.data_type());
|
||||
}
|
||||
}
|
||||
class InferShapesAutoRelease
|
||||
{
|
||||
public:
|
||||
InferShapesAutoRelease(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto)
|
||||
: m_model_proto{model_proto}
|
||||
, m_infer_shapes_was_run{false}
|
||||
{
|
||||
}
|
||||
void infer_shapes()
|
||||
{
|
||||
ONNX_NAMESPACE::shape_inference::InferShapes(*m_model_proto);
|
||||
m_infer_shapes_was_run = true;
|
||||
}
|
||||
~InferShapesAutoRelease()
|
||||
{
|
||||
if (m_infer_shapes_was_run)
|
||||
{
|
||||
m_model_proto->mutable_graph()->clear_value_info();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<ONNX_NAMESPACE::ModelProto> m_model_proto;
|
||||
bool m_infer_shapes_was_run;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// \brief A helper class used to hold the ModelProto object as its field
|
||||
@ -198,9 +237,6 @@ struct onnx_editor::ONNXModelEditor::Impl
|
||||
onnx_common::parse_from_file(model_path))}
|
||||
{
|
||||
}
|
||||
|
||||
void infer_shapes() { ONNX_NAMESPACE::shape_inference::InferShapes(*m_model_proto.get()); }
|
||||
void remove_shape_inference_info() { m_model_proto->mutable_graph()->clear_value_info(); }
|
||||
};
|
||||
|
||||
onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path)
|
||||
@ -274,6 +310,58 @@ void onnx_editor::ONNXModelEditor::set_input_shapes(
|
||||
}
|
||||
}
|
||||
|
||||
PartialShape onnx_editor::ONNXModelEditor::get_tensor_shape(const std::string& tensor_name) const
|
||||
{
|
||||
const ValueInfoProto* value_info = nullptr;
|
||||
auto* onnx_graph = m_pimpl->m_model_proto->mutable_graph();
|
||||
InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto);
|
||||
if (const auto* input = find_graph_input(*onnx_graph, tensor_name))
|
||||
{
|
||||
value_info = input;
|
||||
}
|
||||
else if (const auto* output = find_graph_output(*onnx_graph, tensor_name))
|
||||
{
|
||||
value_info = output;
|
||||
}
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
onnx_shapes.infer_shapes();
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
NGRAPH_WARN << "Cannot replace existing shapes during get_tensor_shape";
|
||||
return PartialShape::dynamic();
|
||||
}
|
||||
auto node_it = std::find_if(std::begin(onnx_graph->value_info()),
|
||||
std::end(onnx_graph->value_info()),
|
||||
[&tensor_name](const ValueInfoProto& value_info) -> bool {
|
||||
return value_info.name() == tensor_name;
|
||||
});
|
||||
if (node_it != std::end(onnx_graph->value_info()))
|
||||
{
|
||||
value_info = &(*node_it);
|
||||
}
|
||||
}
|
||||
if (value_info != nullptr)
|
||||
{
|
||||
const auto& onnx_tensor_type = value_info->type().tensor_type();
|
||||
if (onnx_tensor_type.has_shape())
|
||||
{
|
||||
return onnx_common::to_ng_shape(onnx_tensor_type.shape());
|
||||
}
|
||||
else
|
||||
{
|
||||
return PartialShape::dynamic();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("The tensor: " + tensor_name + " was not found in the graph");
|
||||
}
|
||||
}
|
||||
|
||||
void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdge>& inputs,
|
||||
const std::vector<OutputEdge>& outputs)
|
||||
{
|
||||
@ -282,35 +370,78 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdg
|
||||
return;
|
||||
}
|
||||
|
||||
m_pimpl->infer_shapes();
|
||||
InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto);
|
||||
onnx_shapes.infer_shapes();
|
||||
|
||||
SubgraphExtractor editor{*(m_pimpl->m_model_proto->mutable_graph())};
|
||||
editor.add_new_inputs(inputs);
|
||||
editor.add_new_outputs(outputs);
|
||||
editor.extract_subgraph(outputs);
|
||||
|
||||
m_pimpl->remove_shape_inference_info();
|
||||
m_pimpl->m_is_mapper_updated = false;
|
||||
}
|
||||
|
||||
std::vector<std::string> onnx_editor::ONNXModelEditor::model_inputs() const
|
||||
{
|
||||
const auto& graph = m_pimpl->m_model_proto->graph();
|
||||
std::vector<std::string> inputs;
|
||||
inputs.reserve(graph.input_size() - graph.initializer_size());
|
||||
for (const auto& in : graph.input())
|
||||
{
|
||||
if (std::find_if(graph.initializer().begin(),
|
||||
graph.initializer().end(),
|
||||
[&in](const TensorProto& initializer) {
|
||||
return initializer.name() == in.name();
|
||||
}) == graph.initializer().end())
|
||||
{
|
||||
inputs.push_back(in.name());
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::vector<std::string> inputs_and_initializers;
|
||||
inputs_and_initializers.reserve(graph.input_size() + graph.initializer_size());
|
||||
std::vector<std::string> onnx_editor::ONNXModelEditor::model_outputs() const
|
||||
{
|
||||
const auto& graph = m_pimpl->m_model_proto->graph();
|
||||
std::vector<std::string> outputs;
|
||||
outputs.reserve(graph.output_size());
|
||||
|
||||
std::transform(graph.input().begin(),
|
||||
graph.input().end(),
|
||||
std::back_inserter(inputs_and_initializers),
|
||||
std::transform(graph.output().begin(),
|
||||
graph.output().end(),
|
||||
std::back_inserter(outputs),
|
||||
extract_name<ONNX_NAMESPACE::ValueInfoProto>);
|
||||
|
||||
std::transform(graph.initializer().begin(),
|
||||
graph.initializer().end(),
|
||||
std::back_inserter(inputs_and_initializers),
|
||||
extract_name<ONNX_NAMESPACE::TensorProto>);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
return inputs_and_initializers;
|
||||
bool onnx_editor::ONNXModelEditor::is_input(const InputEdge& edge) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
const auto& port_name = m_pimpl->m_edge_mapper.get_input_port_name(edge);
|
||||
if (port_name.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& inputs = model_inputs();
|
||||
return std::count(std::begin(inputs), std::end(inputs), port_name) > 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_output(const OutputEdge& edge) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
const auto& port_name = m_pimpl->m_edge_mapper.get_output_port_name(edge);
|
||||
if (port_name.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& outputs = model_outputs();
|
||||
return std::count(std::begin(outputs), std::end(outputs), port_name) > 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::string onnx_editor::ONNXModelEditor::model_string() const
|
||||
@ -393,6 +524,12 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN
|
||||
return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_correct_tensor_name(const std::string& name) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.is_correct_tensor_name(name);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> onnx_editor::ONNXModelEditor::decode()
|
||||
{
|
||||
return onnx_import::detail::decode_to_framework_nodes(m_pimpl->m_model_proto, m_model_path);
|
||||
|
551
ngraph/python/tests/test_frontend/test_frontend_onnx_editor.py
Normal file
551
ngraph/python/tests/test_frontend/test_frontend_onnx_editor.py
Normal file
@ -0,0 +1,551 @@
|
||||
# Copyright (C) 2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import onnx
|
||||
import pytest
|
||||
from onnx.helper import make_graph, make_model, make_tensor_value_info
|
||||
from ngraph import PartialShape
|
||||
from ngraph.frontend import FrontEndManager
|
||||
|
||||
|
||||
# in1 in2 in3
|
||||
# | | |
|
||||
# \ / |
|
||||
# +--------+ +------+
|
||||
# | Add | | Relu |
|
||||
# +--------+ +------+
|
||||
# <add_out> |
|
||||
# / \\ |
|
||||
# +--------+ +-----+ out3
|
||||
# | Split | | Mul |
|
||||
# |(split1)|..| |
|
||||
# +--------+ +-----+
|
||||
# / \ |
|
||||
# out1 out2 out4
|
||||
#
|
||||
def create_test_onnx_models():
|
||||
models = {}
|
||||
# Input model
|
||||
add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"])
|
||||
split = onnx.helper.make_node("Split", inputs=["add_out"],
|
||||
outputs=["out1", "out2"], name="split1", axis=0)
|
||||
relu = onnx.helper.make_node("Relu", inputs=["in3"], outputs=["out3"])
|
||||
mul = onnx.helper.make_node("Mul", inputs=["add_out", "add_out"], outputs=["out4"])
|
||||
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
|
||||
models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for extract_subgraph
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for extract_subgraph 2
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for extract_subgraph 3
|
||||
input_tensors = [
|
||||
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
|
||||
outputs=["out1", "out2"], name="split1", axis=0)
|
||||
graph = make_graph([expected_split], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_3.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for extract_subgraph 4
|
||||
input_tensors = [
|
||||
make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
|
||||
outputs=["out1", "out2"])
|
||||
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
|
||||
outputs=["out4"])
|
||||
graph = make_graph([expected_split, expected_mul], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for test_override_all_outputs
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, relu], "test_graph", input_tensors, output_tensors)
|
||||
models["test_override_all_outputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for test_override_all_outputs 2
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add, mul], "test_graph", input_tensors, output_tensors)
|
||||
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# Expected for test_override_all_inputs
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
expected_split = onnx.helper.make_node("Split", inputs=["out1/placeholder_port_0"],
|
||||
outputs=["out1", "out2"])
|
||||
expected_mul = onnx.helper.make_node("Mul", inputs=["out4/placeholder_port_0", "out4/placeholder_port_1"],
|
||||
outputs=["out4"])
|
||||
graph = make_graph([expected_split, relu, expected_mul], "test_graph", input_tensors, output_tensors)
|
||||
models["test_override_all_inputs.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
# test partial shape
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (8, 16)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (8, 16)),
|
||||
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (4, 6)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (4, 16)),
|
||||
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (4, 16)),
|
||||
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (4, 6)),
|
||||
make_tensor_value_info("out4", onnx.TensorProto.FLOAT, (8, 16)),
|
||||
]
|
||||
graph = make_graph([add, split, relu, mul], "test_graph", input_tensors, output_tensors)
|
||||
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer")
|
||||
|
||||
return models
|
||||
|
||||
|
||||
fem = FrontEndManager()
|
||||
test_models_names = []
|
||||
|
||||
|
||||
def setup_module():
|
||||
models = create_test_onnx_models()
|
||||
for name, model in models.items():
|
||||
onnx.save_model(model, name)
|
||||
test_models_names.append(name)
|
||||
|
||||
|
||||
def teardown_module():
|
||||
for name in test_models_names:
|
||||
os.remove(name)
|
||||
|
||||
|
||||
def skip_if_onnx_frontend_is_disabled():
|
||||
front_ends = fem.get_available_front_ends()
|
||||
if "onnx" not in front_ends:
|
||||
pytest.skip()
|
||||
|
||||
|
||||
# Function to compare ng Functions (ops names, types and shapes).
|
||||
# Note that the functions uses get_ordered_ops, so the topological order of ops should be also preserved.
|
||||
def compare_functions(current, expected): # noqa: C901 the function is too complex
|
||||
result = True
|
||||
msg = ""
|
||||
if current.get_friendly_name() != expected.get_friendly_name():
|
||||
result = False
|
||||
msg += "Friendly name of nG Functions not equal. "
|
||||
msg += f"Current: {current.get_friendly_name()}, expected: {expected.get_friendly_name()}. "
|
||||
|
||||
current_ops = current.get_ordered_ops()
|
||||
expected_ops = expected.get_ordered_ops()
|
||||
|
||||
if len(current_ops) != len(expected_ops):
|
||||
result = False
|
||||
msg += "Not equal number of ops. "
|
||||
msg += f"Current: {len(current_ops)}, expected: {len(expected_ops)}. "
|
||||
|
||||
for i in range(len(current_ops)):
|
||||
if (current_ops[i].get_friendly_name() != expected_ops[i].get_friendly_name()
|
||||
and current_ops[i].get_type_name() != "Constant"): # const have different names
|
||||
result = False
|
||||
msg += "Not equal op name. "
|
||||
msg += f"Current: {current_ops[i].get_friendly_name()}, "
|
||||
msg += f"expected: {expected_ops[i].get_friendly_name()}. "
|
||||
if current_ops[i].get_output_size() != expected_ops[i].get_output_size():
|
||||
result = False
|
||||
msg += f"Not equal output size of {current_ops[i].get_friendly_name()}. "
|
||||
for j in range(current_ops[i].get_output_size()):
|
||||
if current_ops[i].get_output_partial_shape(j) != expected_ops[i].get_output_partial_shape(j):
|
||||
result = False
|
||||
msg += f"Not equal op partial shapes of {current_ops[i].get_friendly_name()}. "
|
||||
msg += f"Current: {current_ops[i].get_partial_shape({j})}, "
|
||||
msg += f"expected: {expected_ops[i].get_partial_shape({j})}. "
|
||||
if current_ops[i].get_output_element_type(j) != expected_ops[i].get_output_element_type(j):
|
||||
result = False
|
||||
msg += f"Not equal output element type of {current_ops[i].get_friendly_name()}. "
|
||||
msg += f"Current: {current_ops[i].get_output_element_type(j)}, "
|
||||
msg += f"expected: {expected_ops[i].get_output_element_type(j)}. "
|
||||
|
||||
if not result:
|
||||
print(msg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def test_extract_subgraph():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="add_out").get_input_port(inputPortIndex=0) # in1
|
||||
place2 = model.get_place_by_tensor_name(tensorName="add_out").get_input_port(inputPortIndex=1) # in2
|
||||
place3 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
model.extract_subgraph(inputs=[place1, place2], outputs=[place3])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_2():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out3")
|
||||
model.extract_subgraph(inputs=[], outputs=[place1, place2])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph_2.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_3():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
place3 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
model.extract_subgraph(inputs=[place1], outputs=[place2, place3])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph_3.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_4():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
|
||||
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
place5 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
place6 = model.get_place_by_tensor_name(tensorName="out4")
|
||||
model.extract_subgraph(inputs=[place1, place2, place3], outputs=[place4, place5, place6])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph_4.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_outputs():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out3")
|
||||
model.override_all_outputs(outputs=[place1, place2])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_override_all_outputs.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_outputs_2():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="out4")
|
||||
model.override_all_outputs(outputs=[place1])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_override_all_outputs_2.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_inputs():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_operation_name_and_input_port(
|
||||
operationName="split1", inputPortIndex=0)
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="in3")
|
||||
model.override_all_inputs(inputs=[place1, place2, place3, place4])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_override_all_inputs.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_inputs_exceptions():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="in1")
|
||||
place2 = model.get_place_by_tensor_name(tensorName="in2")
|
||||
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="in3")
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
model.override_all_inputs(inputs=[place1, place2])
|
||||
assert "Unexpected number of inputs after override_all_inputs" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
model.override_all_inputs(inputs=[place3, place4])
|
||||
assert "Unexpected number of inputs after override_all_inputs" in str(e)
|
||||
|
||||
|
||||
def test_is_input_output():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="in2")
|
||||
assert place1.is_input()
|
||||
assert not place1.is_output()
|
||||
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
assert not place2.is_input()
|
||||
assert place2.is_output()
|
||||
|
||||
place3 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
assert not place3.is_input()
|
||||
assert not place3.is_output()
|
||||
|
||||
place4 = place1 = model.get_place_by_operation_name_and_input_port(
|
||||
operationName="split1", inputPortIndex=0)
|
||||
assert not place4.is_input()
|
||||
assert not place4.is_output()
|
||||
|
||||
|
||||
def test_set_partial_shape():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="in1")
|
||||
model.set_partial_shape(place1, PartialShape([8, 16]))
|
||||
place2 = model.get_place_by_tensor_name(tensorName="in2")
|
||||
model.set_partial_shape(place2, PartialShape([8, 16]))
|
||||
place3 = model.get_place_by_tensor_name(tensorName="in3")
|
||||
model.set_partial_shape(place3, PartialShape([4, 6]))
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_partial_shape.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_get_partial_shape():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="in1")
|
||||
assert model.get_partial_shape(place1) == PartialShape([2, 2])
|
||||
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
assert model.get_partial_shape(place2) == PartialShape([1, 2])
|
||||
|
||||
place3 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
assert model.get_partial_shape(place3) == PartialShape([2, 2])
|
||||
|
||||
place4 = model.get_place_by_tensor_name(tensorName="in3")
|
||||
model.set_partial_shape(place4, PartialShape([4, 6]))
|
||||
assert model.get_partial_shape(place4) == PartialShape([4, 6])
|
||||
assert model.get_partial_shape(place2) == PartialShape([1, 2])
|
||||
|
||||
|
||||
def test_get_inputs():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
in_names = [place.get_names()[0] for place in model.get_inputs()]
|
||||
assert in_names == ["in1", "in2", "in3"]
|
||||
|
||||
|
||||
def test_get_outputs():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
out_names = [place.get_names()[0] for place in model.get_outputs()]
|
||||
assert out_names == ["out1", "out2", "out3", "out4"]
|
||||
|
||||
|
||||
def test_is_equal():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="in1")
|
||||
assert place1.is_equal(place1)
|
||||
|
||||
place2 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
assert place2.is_equal(place2)
|
||||
|
||||
place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0)
|
||||
assert place3.is_equal(place4)
|
||||
|
||||
place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place6 = model.get_place_by_tensor_name(tensorName="out1").get_input_port(inputPortIndex=0)
|
||||
assert place5.is_equal(place6)
|
||||
|
||||
place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port()
|
||||
assert place7.is_equal(place7)
|
||||
|
||||
place8 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
assert place8.is_equal(place8)
|
||||
|
||||
assert not place1.is_equal(place2)
|
||||
assert not place6.is_equal(place7)
|
||||
assert not place8.is_equal(place2)
|
||||
|
||||
|
||||
def test_get_place_by_tensor_name():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework="onnx")
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
place1 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
assert place1
|
||||
|
||||
place2 = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
assert place2
|
||||
|
||||
place3 = model.get_place_by_tensor_name(tensorName="in1")
|
||||
assert place3
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
model.get_place_by_tensor_name(tensorName="0:add_out")
|
||||
assert "The tensor with name: 0:add_out does not exist in the graph" in str(e)
|
@ -1462,3 +1462,110 @@ NGRAPH_TEST(onnx_editor, cut_operator_with_no_schema)
|
||||
|
||||
EXPECT_TRUE(result.is_ok) << result.error_message;
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, is_model_input)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
EXPECT_TRUE(editor.is_input(InputEdge{0, 0}));
|
||||
const auto edge1 = editor.find_input_edge(EditorOutput{"add1"}, 1);
|
||||
EXPECT_TRUE(editor.is_input(edge1));
|
||||
|
||||
EXPECT_FALSE(editor.is_input(InputEdge{1, 2}));
|
||||
EXPECT_FALSE(editor.is_input(InputEdge{3, 0}));
|
||||
EXPECT_FALSE(editor.is_input(InputEdge{11, 0}));
|
||||
const auto edge2 = editor.find_input_edge(EditorOutput{"conv1"}, 2);
|
||||
EXPECT_FALSE(editor.is_input(edge2));
|
||||
EXPECT_FALSE(editor.is_input(InputEdge{2, 1})); // initializer is not treated as input
|
||||
const auto edge3 = editor.find_input_edge(EditorOutput{"conv1"}, EditorInput{"in4"});
|
||||
EXPECT_FALSE(editor.is_input(edge3));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, is_model_output)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
EXPECT_TRUE(editor.is_output(OutputEdge{4, 0}));
|
||||
EXPECT_TRUE(editor.is_output(OutputEdge{5, 1}));
|
||||
const auto edge1 = editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{"split2"});
|
||||
EXPECT_TRUE(editor.is_output(edge1));
|
||||
|
||||
EXPECT_FALSE(editor.is_output(OutputEdge{4, 1}));
|
||||
EXPECT_FALSE(editor.is_output(OutputEdge{0, 0}));
|
||||
EXPECT_FALSE(editor.is_output(OutputEdge{11, 0}));
|
||||
const auto edge2 = editor.find_output_edge("add2");
|
||||
EXPECT_FALSE(editor.is_output(edge2));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, model_inputs)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const auto inputs = editor.model_inputs();
|
||||
EXPECT_TRUE(inputs == (std::vector<std::string>{"in1", "in2", "in3"})); // in4 is initializer
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, model_output)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const auto outputs = editor.model_outputs();
|
||||
EXPECT_TRUE(outputs == (std::vector<std::string>{"mul1", "split2", "mul2"}));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, get_tensor_shape)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
EXPECT_EQ(editor.get_tensor_shape("mul2"), (PartialShape{1, 1, 2, 2}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("in1"), (PartialShape{2, 2}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("in2"), (PartialShape{}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("in3"), (PartialShape{1, 1, 2, 2}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("relu1"), (PartialShape{2, 2}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("add1"), (PartialShape{2, 2}));
|
||||
try
|
||||
{
|
||||
editor.get_tensor_shape("not_existed");
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(
|
||||
msg.find("The tensor: not_existed was not found in the graph") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, get_tensor_shape_after_modification)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
EXPECT_EQ(editor.get_tensor_shape("in3"), (PartialShape{1, 1, 2, 2}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("conv1"), (PartialShape{1, 1, 2, 2}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("mul2"), (PartialShape{1, 1, 2, 2}));
|
||||
editor.set_input_shapes({{"in3", (PartialShape{1, 1, 4, 4})}});
|
||||
EXPECT_EQ(editor.get_tensor_shape("conv1"), (PartialShape{1, 1, 4, 4}));
|
||||
EXPECT_EQ(editor.get_tensor_shape("in3"), (PartialShape{1, 1, 4, 4}));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, is_correct_tensor_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
EXPECT_TRUE(editor.is_correct_tensor_name("in1"));
|
||||
EXPECT_TRUE(editor.is_correct_tensor_name("relu1"));
|
||||
EXPECT_TRUE(editor.is_correct_tensor_name("split2"));
|
||||
EXPECT_TRUE(editor.is_correct_tensor_name("mul2"));
|
||||
EXPECT_TRUE(editor.is_correct_tensor_name("in4"));
|
||||
|
||||
EXPECT_FALSE(editor.is_correct_tensor_name("relu1_name"));
|
||||
EXPECT_FALSE(editor.is_correct_tensor_name("not_existed"));
|
||||
EXPECT_FALSE(editor.is_correct_tensor_name(""));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user