[ONNX] Fix "Model references undeclared parameters" exception (#15218)

In case when subgraph has implicit inputs from their indirect parent,
those inputs are not registered in direct parent.
So when subgraph node is created - it references input that is not
available in direct parent's scope.
In this patch, the proposed solution registers the input (the particular subgraph
references), in every (direct or indirect) that subgraph's parent.
This commit is contained in:
Mateusz Tabaka 2023-01-23 10:08:55 +01:00 committed by GitHub
parent ae15937c44
commit 931fd11eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 354 additions and 60 deletions

View File

@ -41,7 +41,7 @@ class ONNX_IMPORTER_API Node {
public:
Node() = delete;
// TODO: hide this ctor since it uses protobufs generated structures
Node(const ONNX_NAMESPACE::NodeProto& node_proto, const Graph& graph);
Node(const ONNX_NAMESPACE::NodeProto& node_proto, Graph* graph);
Node(Node&&) noexcept;
Node(const Node&);

View File

@ -10,7 +10,7 @@
namespace ngraph {
namespace onnx_import {
Subgraph Attribute::get_subgraph(const Graph* parent_graph) const {
Subgraph Attribute::get_subgraph(Graph* parent_graph) const {
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH) {
throw error::attribute::InvalidData{m_attribute_proto->type()};
}

View File

@ -265,7 +265,7 @@ public:
const std::string& get_string() const {
return m_attribute_proto->s();
}
Subgraph get_subgraph(const Graph* parent_graph) const;
Subgraph get_subgraph(Graph* parent_graph) const;
std::vector<Tensor> get_tensor_array() const {
std::vector<Tensor> ret;

View File

@ -220,7 +220,7 @@ void Graph::convert_to_ngraph_nodes() {
unsigned int completed = 0u;
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_model->get_graph().node()) {
const Node node{node_proto, *this};
const Node node{node_proto, this};
if (node.has_subgraphs()) {
const auto& subgraphs = node.get_subgraphs();
for (auto& kv : subgraphs) {
@ -309,7 +309,7 @@ void Graph::decode_to_framework_nodes() {
unsigned int completed = 0u;
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_model->get_graph().node()) {
const Node node{node_proto, *this};
const Node node{node_proto, this};
OutputVector ng_nodes{make_framework_nodes(node)};
set_friendly_names(node, ng_nodes);
// Iterate over the number of outputs for given node in graph.
@ -348,11 +348,11 @@ bool Graph::is_ng_node_in_cache(const std::string& name) const {
return m_cache->contains(name);
}
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const {
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) {
return m_cache->get_node(name);
}
OutputVector Graph::get_ng_outputs() const {
OutputVector Graph::get_ng_outputs() {
OutputVector results;
for (const auto& output : m_model->get_graph().output()) {
const auto& ng_output = get_ng_node_from_cache(output.name());
@ -457,7 +457,7 @@ const OpsetImports& Graph::get_opset_imports() const {
return m_model->get_opset_imports();
}
Subgraph::Subgraph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto, const Graph* parent_graph)
Subgraph::Subgraph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto, Graph* parent_graph)
: Graph(parent_graph->model_dir(),
model_proto,
common::make_unique<GraphCache>(),
@ -471,16 +471,20 @@ bool Subgraph::is_ng_node_in_cache(const std::string& name) const {
return m_parent_graph->is_ng_node_in_cache(name);
}
Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) const {
Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) {
if (m_cache->contains(name)) {
return m_cache->get_node(name);
}
return m_parent_graph->get_ng_node_from_cache(name);
}
OutputVector Subgraph::make_ng_nodes(const Node& onnx_node) {
replace_input_from_parent_scope_with_parameter(onnx_node);
return Graph::make_ng_nodes(onnx_node);
const auto from_parent_node = m_parent_graph->get_ng_node_from_cache(name);
if (op::is_constant(from_parent_node.get_node()))
return from_parent_node;
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
from_parent_node.get_partial_shape());
m_parameter_to_parent_node_map.insert({new_param, name});
m_cache->emplace_node(name, new_param);
m_parameters.push_back(new_param);
m_inputs_from_parent.push_back(name);
return new_param;
}
std::shared_ptr<Function> Subgraph::convert() {
@ -505,30 +509,6 @@ void Subgraph::infer_inputs_from_parent() {
}
}
OutputVector Subgraph::make_framework_nodes(const Node& onnx_node) {
replace_input_from_parent_scope_with_parameter(onnx_node);
return Graph::make_framework_nodes(onnx_node);
}
void Subgraph::replace_input_from_parent_scope_with_parameter(const Node& onnx_node) {
for (std::size_t i = 0; i < onnx_node.get_inputs_size(); ++i) {
const auto& in_name = onnx_node.input(static_cast<int>(i));
if (m_parent_graph->is_ng_node_in_cache(in_name) &&
std::find(m_inputs_from_parent.begin(), m_inputs_from_parent.end(), in_name) ==
m_inputs_from_parent.end()) {
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
if (op::is_constant(from_parent_node.get_node()))
continue;
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
from_parent_node.get_partial_shape());
m_parameter_to_parent_node_map.insert({new_param, in_name});
m_cache->emplace_node(in_name, new_param);
m_parameters.push_back(new_param);
m_inputs_from_parent.push_back(in_name);
}
}
}
} // namespace onnx_import
} // namespace ngraph

View File

@ -33,7 +33,7 @@ public:
Graph& operator=(Graph&&) = default;
std::shared_ptr<Function> decode();
virtual std::shared_ptr<Function> convert();
OutputVector get_ng_outputs() const;
OutputVector get_ng_outputs();
const std::string& get_name() const {
return m_model->get_graph().name();
}
@ -44,8 +44,8 @@ public:
return m_parameters;
}
virtual bool is_ng_node_in_cache(const std::string& name) const;
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
virtual OutputVector make_ng_nodes(const Node& onnx_node);
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name);
OutputVector make_ng_nodes(const Node& onnx_node);
const OpsetImports& get_opset_imports() const;
virtual ~Graph() = default;
@ -62,7 +62,7 @@ protected:
void set_friendly_names(const Node& onnx_node, const OutputVector& ng_subgraph_outputs) const;
protected:
virtual OutputVector make_framework_nodes(const Node& onnx_node);
OutputVector make_framework_nodes(const Node& onnx_node);
void decode_to_framework_nodes();
void convert_to_ngraph_nodes();
void remove_dangling_parameters();
@ -88,7 +88,7 @@ public:
///
/// \param[in] model The ONNX model object.
/// \param[in] parent_graph The reference to the parent graph.
Subgraph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model, const Graph* parent_graph);
Subgraph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model, Graph* parent_graph);
/// \brief Return nodes which are on the edge the subgraph and the parent graph.
/// \return Vector of edge nodes from parent scope.
@ -105,16 +105,11 @@ public:
Subgraph& operator=(Subgraph&&) = default;
bool is_ng_node_in_cache(const std::string& name) const override;
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const override;
OutputVector make_ng_nodes(const Node& onnx_node) override;
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) override;
void infer_inputs_from_parent();
private:
OutputVector make_framework_nodes(const Node& onnx_node) override;
/// \brief Checks if onnx_node has inputs from parent graph and replaces those inputs with Parameters
void replace_input_from_parent_scope_with_parameter(const Node& onnx_node);
const Graph* m_parent_graph;
Graph* m_parent_graph;
std::vector<std::string> m_inputs_from_parent;
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::string> m_parameter_to_parent_node_map;
};

View File

@ -17,11 +17,11 @@ class Node::Impl {
public:
Impl() = delete;
Impl(const ONNX_NAMESPACE::NodeProto& node_proto, const Graph& graph)
Impl(const ONNX_NAMESPACE::NodeProto& node_proto, Graph* graph)
: m_node_proto{&node_proto},
m_name{node_proto.has_name() ? node_proto.name() : ""},
m_domain{get_node_domain(node_proto)},
m_graph{&graph},
m_graph{graph},
m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())} {
const auto& attributes = node_proto.attribute();
m_attributes.reserve(attributes.size());
@ -34,12 +34,12 @@ public:
}
Impl(const ONNX_NAMESPACE::NodeProto& node_proto,
const Graph& graph,
Graph* graph,
const std::unordered_map<std::string, std::shared_ptr<Subgraph>>& subgraphs)
: m_node_proto{&node_proto},
m_name{node_proto.has_name() ? node_proto.name() : ""},
m_domain{get_node_domain(node_proto)},
m_graph{&graph},
m_graph{graph},
m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())},
m_subgraphs(subgraphs) {
for (const auto& attr_proto : node_proto.attribute()) {
@ -87,7 +87,7 @@ public:
element::Type type) const;
const ONNX_NAMESPACE::NodeProto& node_proto() const;
const Graph& graph() const;
Graph* graph() const;
private:
Subgraph get_subgraph_from_attribute(const std::string& name) const;
@ -95,7 +95,7 @@ private:
const ONNX_NAMESPACE::NodeProto* m_node_proto;
std::string m_name;
std::string m_domain;
const Graph* m_graph;
Graph* m_graph;
std::vector<Attribute> m_attributes;
std::vector<std::reference_wrapper<const std::string>> m_output_names;
mutable std::string m_description;
@ -106,8 +106,8 @@ private:
const ONNX_NAMESPACE::NodeProto& Node::Impl::node_proto() const {
return *m_node_proto;
}
const Graph& Node::Impl::graph() const {
return *m_graph;
Graph* Node::Impl::graph() const {
return m_graph;
}
const std::vector<Attribute>& Node::Impl::attributes() const {
return m_attributes;
@ -287,7 +287,7 @@ std::shared_ptr<ov::op::v0::Constant> Node::Impl::get_attribute_as_constant(cons
return ov::op::v0::Constant::create(type != element::undefined ? type : element::i64, {value.size()}, value);
}
Node::Node(const ONNX_NAMESPACE::NodeProto& node_proto, const Graph& graph)
Node::Node(const ONNX_NAMESPACE::NodeProto& node_proto, Graph* graph)
: m_pimpl{new Impl{node_proto, graph}, [](Impl* impl) {
delete impl;
}} {}

View File

@ -0,0 +1,290 @@
ir_version: 6
producer_name: "nGraph ONNX Importer"
graph {
name: "if inside if inside loop"
node {
input: "trip_count"
input: ""
input: "out_init"
output: "out_final"
output: "out_values"
op_type: "Loop"
attribute {
name: "body"
type: GRAPH
g {
name: "loop body"
node {
output: "two"
op_type: "Constant"
attribute {
name: "value"
type: TENSOR
t {
dims: 1
data_type: 7
int64_data: 2
}
}
}
node {
input: "i"
input: "two"
output: "greater_than_two"
name: "Greater_1"
op_type: "Greater"
}
node {
output: "three"
op_type: "Constant"
attribute {
name: "value"
type: TENSOR
t {
dims: 1
data_type: 7
int64_data: 3
}
}
}
node {
input: "i"
input: "three"
output: "greater_than_three"
name: "Greater_2"
op_type: "Greater"
}
node {
input: "greater_than_two"
output: "if_1_out"
name: "If_1"
op_type: "If"
attribute {
name: "then_branch"
type: GRAPH
g {
name: "then_branch"
node {
input: "greater_than_three"
output: "if_2_out"
name: "If_2"
op_type: "If"
attribute {
name: "then_branch"
type: GRAPH
g {
name: "then_branch"
node {
input: "i"
input: "two"
output: "mul_1"
name: "Mul_1"
op_type: "Mul"
}
node {
input: "mul_1"
output: "cast_1"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "a_in"
input: "cast_1"
output: "mul_2"
name: "Mul_2"
op_type: "Mul"
}
output {
name: "mul_2"
type {
tensor_type {
elem_type: 1
}
}
}
}
}
attribute {
name: "else_branch"
type: GRAPH
g {
node {
input: "i"
input: "three"
output: "mul_3"
name: "Mul_3"
op_type: "Mul"
}
name: "else_branch"
node {
input: "mul_3"
output: "cast_2"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "a_in"
input: "cast_2"
output: "mul_4"
name: "Mul_4"
op_type: "Mul"
}
output {
name: "mul_4"
type {
tensor_type {
elem_type: 1
}
}
}
}
}
}
output {
name: "if_2_out"
type {
tensor_type {
elem_type: 1
}
}
}
}
}
attribute {
name: "else_branch"
type: GRAPH
g {
name: "else_branch"
node {
input: "a_in"
input: "a_in"
output: "add_5"
name: "Add_5"
op_type: "Add"
}
output {
name: "add_5"
type {
tensor_type {
elem_type: 1
}
}
}
}
}
}
node {
input: "cond"
output: "cond_out"
name: "cond_identity"
op_type: "Identity"
}
node {
input: "if_1_out"
output: "a_out"
name: "output_accumulator"
op_type: "Identity"
}
input {
name: "i"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "cond"
type {
tensor_type {
elem_type: 9
}
}
}
input {
name: "a_in"
type {
tensor_type {
elem_type: 1
}
}
}
output {
name: "cond_out"
type {
tensor_type {
elem_type: 9
}
}
}
output {
name: "if_1_out"
type {
tensor_type {
elem_type: 1
}
}
}
output {
name: "a_out"
type {
tensor_type {
elem_type: 1
}
}
}
}
}
}
initializer {
dims: 1
data_type: 7
int64_data: 5
name: "trip_count"
}
input {
name: "out_init"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "out_final"
type {
tensor_type {
elem_type: 1
}
}
}
output {
name: "out_values"
type {
tensor_type {
elem_type: 1
}
}
}
}
opset_import {
version: 11
}

View File

@ -776,6 +776,34 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_if_with_only_indentity_in_else_branch) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_inside_if_inside_loop) {
/*
for (i = 0; i < 5; i++) {
if (i > 2) {
if (i > 3)
out *= float(i * 2)
else
out *= float(i * 3)
} else {
out += out
}
}
*/
const auto function =
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/controlflow/if_inside_if_inside_loop.onnx"));
auto test_case = test::TestCase(function, s_device);
// out_init
test_case.add_input<float>({1.f});
test_case.add_expected_output<float>(Shape{1}, {576});
test_case.add_expected_output<float>(Shape{5, 1}, {2, 4, 8, 72, 576});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_dynamic_inputs) {
/*
if (condition) {

View File

@ -32,6 +32,7 @@ INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
# ONNX evaluate method not implemented for If operator
INTERPRETER.onnx_if_inside_if
INTERPRETER.onnx_if_inside_loop
INTERPRETER.onnx_if_inside_if_inside_loop
# Activation function hardsigmoid unsupported
onnx_model_gru_fwd_activations_relu_hardsigmoid