Extend ONNX Importer for operation "If" (#7319)
This commit is contained in:
@@ -66,8 +66,8 @@ public:
|
||||
|
||||
bool has_attribute(const std::string& name) const;
|
||||
|
||||
bool has_subgraph() const;
|
||||
std::shared_ptr<Subgraph> get_subgraph() const;
|
||||
bool has_subgraphs() const;
|
||||
const std::unordered_map<std::string, std::shared_ptr<Subgraph>>& get_subgraphs() const;
|
||||
|
||||
template <typename T>
|
||||
T get_attribute_value(const std::string& name, T default_value) const;
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
Subgraph Attribute::get_subgraph(const Graph& parent_graph) const {
|
||||
Subgraph Attribute::get_subgraph(const Graph* parent_graph) const {
|
||||
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH) {
|
||||
throw error::attribute::InvalidData{m_attribute_proto->type()};
|
||||
}
|
||||
@@ -21,7 +21,7 @@ Subgraph Attribute::get_subgraph(const Graph& parent_graph) const {
|
||||
model_proto->mutable_graph()->CopyFrom(graph);
|
||||
|
||||
// set opset version and domain from the parent graph
|
||||
model_proto->mutable_opset_import()->CopyFrom(parent_graph.get_opset_imports());
|
||||
model_proto->mutable_opset_import()->CopyFrom(parent_graph->get_opset_imports());
|
||||
return Subgraph{model_proto, parent_graph};
|
||||
}
|
||||
|
||||
|
||||
@@ -301,7 +301,7 @@ public:
|
||||
const std::string& get_string() const {
|
||||
return m_attribute_proto->s();
|
||||
}
|
||||
Subgraph get_subgraph(const Graph& parent_graph) const;
|
||||
Subgraph get_subgraph(const Graph* parent_graph) const;
|
||||
|
||||
std::vector<Tensor> get_tensor_array() const {
|
||||
return {std::begin(m_attribute_proto->tensors()), std::end(m_attribute_proto->tensors())};
|
||||
|
||||
@@ -164,9 +164,12 @@ void Graph::convert_to_ngraph_nodes() {
|
||||
// Process ONNX graph nodes, convert to nGraph nodes
|
||||
for (const auto& node_proto : m_model->get_graph().node()) {
|
||||
const Node node{node_proto, *this};
|
||||
if (node.has_subgraph()) {
|
||||
auto subgraph = node.get_subgraph();
|
||||
auto body_func = subgraph->convert();
|
||||
if (node.has_subgraphs()) {
|
||||
const auto& subgraphs = node.get_subgraphs();
|
||||
for (auto& kv : subgraphs) {
|
||||
auto& subgraph = kv.second;
|
||||
subgraph->convert();
|
||||
}
|
||||
}
|
||||
OutputVector ng_nodes{make_ng_nodes(node)};
|
||||
}
|
||||
@@ -203,12 +206,21 @@ void Graph::decode_to_framework_nodes() {
|
||||
for (const auto& node_proto : m_model->get_graph().node()) {
|
||||
const Node node{node_proto, *this};
|
||||
std::shared_ptr<frontend::ONNXFrameworkNode> framework_node;
|
||||
if (node.has_subgraph()) {
|
||||
auto subgraph = node.get_subgraph();
|
||||
auto body_func = subgraph->decode();
|
||||
if (node.has_subgraphs()) {
|
||||
const auto& subgraphs = node.get_subgraphs();
|
||||
auto inputs = node.get_ng_inputs();
|
||||
for (const auto& input : subgraph->get_inputs_from_parent())
|
||||
inputs.push_back(input);
|
||||
for (const auto& kv : subgraphs) {
|
||||
auto& subgraph = kv.second;
|
||||
subgraph->decode();
|
||||
for (const auto& input : subgraph->get_inputs_from_parent()) {
|
||||
const auto& name = input.get_node()->get_friendly_name();
|
||||
if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output<ngraph::Node>& n) -> bool {
|
||||
return name == n.get_node()->get_friendly_name();
|
||||
}) == inputs.end()) {
|
||||
inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
framework_node =
|
||||
std::make_shared<ngraph::frontend::ONNXSubgraphFrameworkNode>(shared_from_this(), node, inputs);
|
||||
} else {
|
||||
@@ -239,8 +251,8 @@ std::shared_ptr<Function> Graph::decode() {
|
||||
return create_function();
|
||||
}
|
||||
|
||||
const GraphCache& Graph::get_graph_cache() const {
|
||||
return *m_cache.get();
|
||||
bool Graph::is_ng_node_in_cache(const std::string& name) const {
|
||||
return m_cache->contains(name);
|
||||
}
|
||||
|
||||
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const {
|
||||
@@ -309,15 +321,34 @@ const OpsetImports& Graph::get_opset_imports() const {
|
||||
return m_model->get_opset_imports();
|
||||
}
|
||||
|
||||
Subgraph::Subgraph(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto, const Graph& parent_graph)
|
||||
Subgraph::Subgraph(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model_proto, const Graph* parent_graph)
|
||||
: Graph(model_proto, common::make_unique<GraphCache>()),
|
||||
m_parent_graph_cache(&parent_graph.get_graph_cache()) {}
|
||||
m_parent_graph(parent_graph) {}
|
||||
|
||||
bool Subgraph::is_ng_node_in_cache(const std::string& name) const {
|
||||
if (m_cache->contains(name)) {
|
||||
return true;
|
||||
}
|
||||
return m_parent_graph->is_ng_node_in_cache(name);
|
||||
}
|
||||
|
||||
Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) const {
|
||||
if (m_cache->contains(name)) {
|
||||
return m_cache->get_node(name);
|
||||
}
|
||||
return m_parent_graph_cache->get_node(name);
|
||||
return m_parent_graph->get_ng_node_from_cache(name);
|
||||
}
|
||||
|
||||
void Subgraph::replace_input_from_parent_scope_with_parameter(const std::string& in_name,
|
||||
const Output<ngraph::Node>& from_parent_node,
|
||||
Input<ngraph::Node>&& node_to_replace_input) {
|
||||
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
|
||||
from_parent_node.get_partial_shape());
|
||||
node_to_replace_input.replace_source_output(new_param);
|
||||
m_parameter_to_parent_node_map.insert({new_param, in_name});
|
||||
m_cache->emplace_node(in_name, new_param);
|
||||
m_parameters.push_back(new_param);
|
||||
m_inputs_from_parent.push_back(in_name);
|
||||
}
|
||||
|
||||
void Subgraph::find_inputs_from_parent() {
|
||||
@@ -326,28 +357,47 @@ void Subgraph::find_inputs_from_parent() {
|
||||
for (const auto& node_proto : m_model->get_graph().node()) {
|
||||
int input_index = 0;
|
||||
for (const auto& in_name : node_proto.input()) {
|
||||
if (m_parent_graph_cache->contains(in_name)) {
|
||||
const auto& from_parent_node = m_parent_graph_cache->get_node(in_name);
|
||||
if (m_parent_graph->is_ng_node_in_cache(in_name)) {
|
||||
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
|
||||
// constants are skipped
|
||||
if (!ngraph::is_type<ngraph::op::Constant>(from_parent_node.get_node_shared_ptr())) {
|
||||
for (const auto& out_name : node_proto.output()) {
|
||||
if (m_cache->contains(out_name)) {
|
||||
auto out_node_to_replace_input = m_cache->get_node(out_name);
|
||||
auto new_param =
|
||||
std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
|
||||
from_parent_node.get_partial_shape());
|
||||
// replace input from parent scope with parameter
|
||||
out_node_to_replace_input.get_node()->input(input_index).replace_source_output(new_param);
|
||||
m_parameter_to_parent_node_map.insert({new_param, in_name});
|
||||
m_cache->emplace_node(in_name, new_param);
|
||||
m_parameters.push_back(new_param);
|
||||
m_inputs_from_parent.push_back(in_name);
|
||||
auto node_to_replace_input = m_cache->get_node(out_name);
|
||||
replace_input_from_parent_scope_with_parameter(
|
||||
in_name,
|
||||
from_parent_node,
|
||||
node_to_replace_input.get_node()->input(input_index));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
++input_index;
|
||||
}
|
||||
// Nodes with subgraphs (like Loop or If) can have implicit inputs (so their subgraphs depend on nodes from
|
||||
// parent) Those implicit inputs are not present in `node_proto.input()` list so to get them, we need to fetch
|
||||
// node's nGraph representation and then we can match those inputs with parent nodes
|
||||
for (const auto& out_name : node_proto.output()) {
|
||||
if (m_cache->contains(out_name)) {
|
||||
auto node_to_replace_input = m_cache->get_node(out_name).get_node();
|
||||
if (!dynamic_cast<op::util::MultiSubGraphOp*>(node_to_replace_input))
|
||||
continue;
|
||||
auto inputs = node_to_replace_input->input_values();
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
const auto& input = inputs.at(i);
|
||||
auto input_node = input.get_node();
|
||||
if (op::is_constant(input_node))
|
||||
continue;
|
||||
const auto& in_name = input_node->get_friendly_name();
|
||||
if (m_parent_graph->is_ng_node_in_cache(in_name)) {
|
||||
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
|
||||
replace_input_from_parent_scope_with_parameter(in_name,
|
||||
from_parent_node,
|
||||
node_to_replace_input->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -365,19 +415,20 @@ void Subgraph::decode_to_framework_nodes() {
|
||||
const std::vector<Output<ngraph::Node>> Subgraph::get_inputs_from_parent() const {
|
||||
OutputVector result;
|
||||
for (const auto& name : m_inputs_from_parent) {
|
||||
result.push_back(m_parent_graph_cache->get_node(name));
|
||||
result.push_back(m_parent_graph->get_ng_node_from_cache(name));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void Subgraph::infer_inputs_from_parent() {
|
||||
for (auto& it : m_parameter_to_parent_node_map) {
|
||||
const auto& node = m_parent_graph_cache->get_node(it.second);
|
||||
const auto& node = m_parent_graph->get_ng_node_from_cache(it.second);
|
||||
auto& parameter = it.first;
|
||||
parameter->set_element_type(node.get_element_type());
|
||||
parameter->set_partial_shape(node.get_partial_shape());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
||||
|
||||
@@ -34,10 +34,10 @@ public:
|
||||
const std::string& get_name() const {
|
||||
return m_model->get_graph().name();
|
||||
}
|
||||
const GraphCache& get_graph_cache() const;
|
||||
const ParameterVector& get_ng_parameters() const {
|
||||
return m_parameters;
|
||||
}
|
||||
virtual bool is_ng_node_in_cache(const std::string& name) const;
|
||||
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
|
||||
OutputVector make_ng_nodes(const Node& onnx_node) const;
|
||||
const OpsetImports& get_opset_imports() const;
|
||||
@@ -71,7 +71,7 @@ public:
|
||||
///
|
||||
/// \param[in] model The ONNX model object.
|
||||
/// \param[in] parent_graph The reference to the parent graph.
|
||||
Subgraph(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model, const Graph& parent_graph);
|
||||
Subgraph(std::shared_ptr<ONNX_NAMESPACE::ModelProto> model, const Graph* parent_graph);
|
||||
|
||||
/// \brief Return nodes which are on the edge the subgraph and the parent graph.
|
||||
/// \return Vector of edge nodes from parent scope.
|
||||
@@ -87,14 +87,23 @@ public:
|
||||
Subgraph& operator=(const Subgraph&) = delete;
|
||||
Subgraph& operator=(Subgraph&&) = default;
|
||||
|
||||
bool is_ng_node_in_cache(const std::string& name) const override;
|
||||
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const override;
|
||||
void infer_inputs_from_parent();
|
||||
|
||||
private:
|
||||
void decode_to_framework_nodes() override;
|
||||
void find_inputs_from_parent();
|
||||
/// \brief Replaces current node's input with Parameter if that input comes from parent graph scope
|
||||
///
|
||||
/// \param[in] in_name input node name
|
||||
/// \param[in] from_parent_node nGraph node from parent scope
|
||||
/// \param[in] node_to_replace_input nGraph input node to be replaced
|
||||
void replace_input_from_parent_scope_with_parameter(const std::string& in_name,
|
||||
const Output<ngraph::Node>& from_parent_node,
|
||||
Input<ngraph::Node>&& node_to_replace_input);
|
||||
|
||||
const GraphCache* m_parent_graph_cache;
|
||||
const Graph* m_parent_graph;
|
||||
std::vector<std::string> m_inputs_from_parent;
|
||||
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::string> m_parameter_to_parent_node_map;
|
||||
};
|
||||
|
||||
@@ -24,24 +24,22 @@ public:
|
||||
m_graph{&graph},
|
||||
m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())},
|
||||
m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())} {
|
||||
const auto it = std::find_if(std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
|
||||
return attribute.is_graph();
|
||||
});
|
||||
m_has_subgraph = it != std::end(m_attributes);
|
||||
if (m_has_subgraph) {
|
||||
m_subgraph = std::make_shared<Subgraph>(it->get_subgraph(*m_graph));
|
||||
for (const auto& attribute : m_attributes) {
|
||||
if (attribute.is_graph())
|
||||
m_subgraphs.insert({attribute.get_name(), std::make_shared<Subgraph>(attribute.get_subgraph(m_graph))});
|
||||
}
|
||||
}
|
||||
|
||||
Impl(const ONNX_NAMESPACE::NodeProto& node_proto, const Graph& graph, std::shared_ptr<Subgraph> subgraph)
|
||||
Impl(const ONNX_NAMESPACE::NodeProto& node_proto,
|
||||
const Graph& graph,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Subgraph>>& subgraphs)
|
||||
: m_node_proto{&node_proto},
|
||||
m_name{node_proto.has_name() ? node_proto.name() : ""},
|
||||
m_domain{get_node_domain(node_proto)},
|
||||
m_graph{&graph},
|
||||
m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())},
|
||||
m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())},
|
||||
m_has_subgraph(subgraph != nullptr),
|
||||
m_subgraph(subgraph) {}
|
||||
m_subgraphs(subgraphs) {}
|
||||
|
||||
const std::vector<Attribute>& attributes() const;
|
||||
OutputVector get_ng_inputs() const;
|
||||
@@ -57,8 +55,8 @@ public:
|
||||
|
||||
bool has_attribute(const std::string& name) const;
|
||||
|
||||
bool has_subgraph() const;
|
||||
std::shared_ptr<Subgraph> get_subgraph() const;
|
||||
bool has_subgraphs() const;
|
||||
const std::unordered_map<std::string, std::shared_ptr<Subgraph>>& get_subgraphs() const;
|
||||
|
||||
template <typename T>
|
||||
T get_attribute_value(const std::string& name, T default_value) const;
|
||||
@@ -80,8 +78,7 @@ private:
|
||||
std::vector<std::reference_wrapper<const std::string>> m_output_names;
|
||||
mutable std::string m_description;
|
||||
|
||||
bool m_has_subgraph;
|
||||
std::shared_ptr<Subgraph> m_subgraph;
|
||||
std::unordered_map<std::string, std::shared_ptr<Subgraph>> m_subgraphs;
|
||||
};
|
||||
|
||||
const ONNX_NAMESPACE::NodeProto& Node::Impl::node_proto() const {
|
||||
@@ -127,15 +124,15 @@ Subgraph Node::Impl::get_subgraph_from_attribute(const std::string& name) const
|
||||
if (it == std::end(m_attributes)) {
|
||||
throw error::node::UnknownAttribute{this->name(), name};
|
||||
}
|
||||
return it->get_subgraph(*m_graph);
|
||||
return it->get_subgraph(m_graph);
|
||||
}
|
||||
|
||||
bool Node::Impl::has_subgraph() const {
|
||||
return m_has_subgraph;
|
||||
bool Node::Impl::has_subgraphs() const {
|
||||
return m_subgraphs.size() > 0;
|
||||
}
|
||||
|
||||
std::shared_ptr<Subgraph> Node::Impl::get_subgraph() const {
|
||||
return m_subgraph;
|
||||
const std::unordered_map<std::string, std::shared_ptr<Subgraph>>& Node::Impl::get_subgraphs() const {
|
||||
return m_subgraphs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -198,7 +195,7 @@ Node::Node(const ONNX_NAMESPACE::NodeProto& node_proto, const Graph& graph)
|
||||
Node::Node(Node&& other) noexcept : m_pimpl{std::move(other.m_pimpl)} {}
|
||||
|
||||
Node::Node(const Node& other)
|
||||
: m_pimpl{new Impl{other.m_pimpl->node_proto(), other.m_pimpl->graph(), other.get_subgraph()}, [](Impl* impl) {
|
||||
: m_pimpl{new Impl{other.m_pimpl->node_proto(), other.m_pimpl->graph(), other.get_subgraphs()}, [](Impl* impl) {
|
||||
delete impl;
|
||||
}} {}
|
||||
|
||||
@@ -231,12 +228,12 @@ bool Node::has_attribute(const std::string& name) const {
|
||||
return m_pimpl->has_attribute(name);
|
||||
}
|
||||
|
||||
bool Node::has_subgraph() const {
|
||||
return m_pimpl->has_subgraph();
|
||||
bool Node::has_subgraphs() const {
|
||||
return m_pimpl->has_subgraphs();
|
||||
}
|
||||
|
||||
std::shared_ptr<Subgraph> Node::get_subgraph() const {
|
||||
return m_pimpl->get_subgraph();
|
||||
const std::unordered_map<std::string, std::shared_ptr<Subgraph>>& Node::get_subgraphs() const {
|
||||
return m_pimpl->get_subgraphs();
|
||||
}
|
||||
|
||||
std::vector<std::string> Node::get_attribute_names() const {
|
||||
|
||||
@@ -86,14 +86,19 @@ public:
|
||||
: ONNXFrameworkNode(graph, node, inputs) {}
|
||||
|
||||
void infer_inputs_from_parent() {
|
||||
m_node.get_subgraph()->infer_inputs_from_parent();
|
||||
for (auto& subgraph : m_node.get_subgraphs())
|
||||
subgraph.second->infer_inputs_from_parent();
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_subgraph_body() const {
|
||||
auto subgraph = m_node.get_subgraph();
|
||||
return std::make_shared<Function>(subgraph->get_ng_outputs(),
|
||||
subgraph->get_ng_parameters(),
|
||||
subgraph->get_name());
|
||||
std::vector<std::shared_ptr<Function>> get_subgraph_functions() const {
|
||||
std::vector<std::shared_ptr<Function>> ret;
|
||||
for (const auto& kv : m_node.get_subgraphs()) {
|
||||
auto& subgraph = kv.second;
|
||||
ret.push_back(std::make_shared<Function>(subgraph->get_ng_outputs(),
|
||||
subgraph->get_ng_parameters(),
|
||||
subgraph->get_name()));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
69
ngraph/frontend/onnx/frontend/src/op/if.cpp
Normal file
69
ngraph/frontend/onnx/frontend/src/op/if.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "op/if.hpp"
|
||||
|
||||
#include "core/graph.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/opsets/opset8.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
OutputVector if_op(const Node& node) {
|
||||
const auto& ng_inputs = node.get_ng_inputs();
|
||||
NGRAPH_CHECK(ng_inputs.size() == 1, "If operator takes only one input");
|
||||
|
||||
const auto& subgraphs = node.get_subgraphs();
|
||||
NGRAPH_CHECK(subgraphs.count("then_branch") == 1, "Missing 'then_branch' attribute");
|
||||
auto then_subgraph = subgraphs.at("then_branch");
|
||||
const auto& then_params = then_subgraph->get_ng_parameters();
|
||||
auto then_branch =
|
||||
std::make_shared<Function>(then_subgraph->get_ng_outputs(), then_params, then_subgraph->get_name());
|
||||
NGRAPH_CHECK(subgraphs.count("else_branch") == 1, "Missing 'else_branch' attribute");
|
||||
auto else_subgraph = subgraphs.at("else_branch");
|
||||
const auto& else_params = else_subgraph->get_ng_parameters();
|
||||
auto else_branch =
|
||||
std::make_shared<Function>(else_subgraph->get_ng_outputs(), else_params, else_subgraph->get_name());
|
||||
|
||||
auto if_node = std::make_shared<ngraph::opset8::If>(ng_inputs.at(0));
|
||||
if_node->set_then_body(then_branch);
|
||||
if_node->set_else_body(else_branch);
|
||||
|
||||
const auto then_branch_inputs_from_parent = then_subgraph->get_inputs_from_parent();
|
||||
NGRAPH_CHECK(then_branch_inputs_from_parent.size() == then_params.size(),
|
||||
"Number of inputs to 'then_branch' is invalid. Expected " +
|
||||
std::to_string(then_branch_inputs_from_parent.size()) + ", actual " +
|
||||
std::to_string(then_params.size()));
|
||||
auto then_param = then_params.cbegin();
|
||||
for (const auto& from_parent : then_branch_inputs_from_parent) {
|
||||
if_node->set_input(from_parent, *then_param, nullptr);
|
||||
then_param++;
|
||||
}
|
||||
const auto else_branch_inputs_from_parent = else_subgraph->get_inputs_from_parent();
|
||||
NGRAPH_CHECK(else_branch_inputs_from_parent.size() == else_params.size(),
|
||||
"Number of inputs to 'else_branch' is invalid. Expected " +
|
||||
std::to_string(else_branch_inputs_from_parent.size()) + ", actual " +
|
||||
std::to_string(else_params.size()));
|
||||
auto else_param = else_params.cbegin();
|
||||
for (const auto& from_parent : else_branch_inputs_from_parent) {
|
||||
if_node->set_input(from_parent, nullptr, *else_param);
|
||||
else_param++;
|
||||
}
|
||||
NGRAPH_CHECK(then_branch->get_results().size() == else_branch->get_results().size(),
|
||||
"'then' and 'else' branches have to have the same number of outputs");
|
||||
auto else_result = else_branch->get_results().cbegin();
|
||||
for (const auto& then_result : then_branch->get_results()) {
|
||||
if_node->set_output(then_result, *else_result);
|
||||
else_result++;
|
||||
}
|
||||
if_node->validate_and_infer_types();
|
||||
|
||||
return if_node->outputs();
|
||||
}
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
25
ngraph/frontend/onnx/frontend/src/op/if.hpp
Normal file
25
ngraph/frontend/onnx/frontend/src/op/if.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
/// \brief Convert ONNX If operation to an nGraph node.
|
||||
///
|
||||
/// \param node The ONNX node object representing this operation.
|
||||
///
|
||||
/// \return The vector containing Ngraph nodes producing output of ONNX If
|
||||
/// operation.
|
||||
OutputVector if_op(const Node& node);
|
||||
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
@@ -59,7 +59,8 @@ OutputVector loop(const Node& node) {
|
||||
loop_carried_dependencies_map[i + 2] = loop_carried_dependencies[i].get_node()->get_friendly_name();
|
||||
}
|
||||
|
||||
auto body_graph = node.get_subgraph();
|
||||
const auto& subgraphs = node.get_subgraphs();
|
||||
auto body_graph = subgraphs.at("body");
|
||||
auto body_outputs = body_graph->get_ng_outputs();
|
||||
const auto& body_inputs = body_graph->get_ng_parameters();
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@
|
||||
#include "op/hard_swish.hpp"
|
||||
#include "op/hardmax.hpp"
|
||||
#include "op/identity.hpp"
|
||||
#include "op/if.hpp"
|
||||
#include "op/image_scaler.hpp"
|
||||
#include "op/instance_norm.hpp"
|
||||
#include "op/leaky_relu.hpp"
|
||||
@@ -338,6 +339,7 @@ OperatorsBridge::OperatorsBridge() {
|
||||
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
|
||||
REGISTER_OPERATOR("HardSwish", 1, hard_swish);
|
||||
REGISTER_OPERATOR("Identity", 1, identity);
|
||||
REGISTER_OPERATOR("If", 1, if_op);
|
||||
REGISTER_OPERATOR("ImageScaler", 1, image_scaler);
|
||||
REGISTER_OPERATOR("InstanceNormalization", 1, instance_norm);
|
||||
REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
|
||||
|
||||
@@ -55,7 +55,9 @@ void convert_decoded_function(std::shared_ptr<Function> function) {
|
||||
if (auto raw_node = std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node)) {
|
||||
if (auto subgraph_node = std::dynamic_pointer_cast<frontend::ONNXSubgraphFrameworkNode>(node)) {
|
||||
subgraph_node->infer_inputs_from_parent();
|
||||
convert_decoded_function(subgraph_node->get_subgraph_body());
|
||||
for (auto& function : subgraph_node->get_subgraph_functions()) {
|
||||
convert_decoded_function(function);
|
||||
}
|
||||
}
|
||||
auto ng_nodes = raw_node->get_ng_nodes();
|
||||
replace_node(raw_node, ng_nodes);
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "then_branch"
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
node {
|
||||
input: "y"
|
||||
output: "abs"
|
||||
name: "abs"
|
||||
op_type: "Abs"
|
||||
}
|
||||
name: "else_branch"
|
||||
output {
|
||||
name: "abs"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if1"
|
||||
output: "if2"
|
||||
output: "if3"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
output: "split1_1"
|
||||
output: "split1_2"
|
||||
output: "split1_3"
|
||||
name: "split1"
|
||||
op_type: "Split"
|
||||
attribute {
|
||||
name: "axis"
|
||||
type: INT
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "split1_1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "split1_2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "split1_3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
output: "split2_1"
|
||||
output: "split2_2"
|
||||
output: "split2_3"
|
||||
name: "split2"
|
||||
op_type: "Split"
|
||||
attribute {
|
||||
name: "axis"
|
||||
type: INT
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "perm"
|
||||
name: "perm"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
type: TENSOR
|
||||
t {
|
||||
dims: 2
|
||||
data_type: 6
|
||||
int32_data: 1
|
||||
int32_data: 0
|
||||
name: "perm"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "split2_1"
|
||||
input: "perm"
|
||||
output: "transpose_1"
|
||||
name: "transpose_1"
|
||||
op_type: "Transpose"
|
||||
}
|
||||
node {
|
||||
input: "split2_2"
|
||||
input: "perm"
|
||||
output: "transpose_2"
|
||||
name: "transpose_2"
|
||||
op_type: "Transpose"
|
||||
}
|
||||
node {
|
||||
input: "split2_3"
|
||||
input: "perm"
|
||||
output: "transpose_3"
|
||||
name: "transpose_3"
|
||||
op_type: "Transpose"
|
||||
}
|
||||
output {
|
||||
name: "transpose_1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "transpose_2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "transpose_3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "then_branch"
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "mul"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
name: "else_branch"
|
||||
output {
|
||||
name: "mul"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "then_branch"
|
||||
node {
|
||||
output: "const1"
|
||||
name: "const1"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
type: TENSOR
|
||||
t {
|
||||
dims: 2
|
||||
dims: 4
|
||||
data_type: 1
|
||||
float_data: 0
|
||||
float_data: 1
|
||||
float_data: 2
|
||||
float_data: 3
|
||||
float_data: 4
|
||||
float_data: 5
|
||||
float_data: 6
|
||||
float_data: 7
|
||||
name: "const_tensor"
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "const1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
node {
|
||||
output: "const2"
|
||||
name: "const2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 2
|
||||
dims: 4
|
||||
data_type: 1
|
||||
float_data: 0
|
||||
float_data: 5
|
||||
float_data: 10
|
||||
float_data: 15
|
||||
float_data: 20
|
||||
float_data: 25
|
||||
float_data: 20
|
||||
float_data: 15
|
||||
name: "const_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
name: "else_branch"
|
||||
output {
|
||||
name: "const2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "then_branch"
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "mul"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
name: "else_branch"
|
||||
output {
|
||||
name: "mul"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
234
ngraph/test/models/onnx/controlflow/if_inside_if.prototxt
Normal file
234
ngraph/test/models/onnx/controlflow/if_inside_if.prototxt
Normal file
@@ -0,0 +1,234 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
name: "then_branch"
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "greater"
|
||||
name: "greater"
|
||||
op_type: "Greater"
|
||||
}
|
||||
node {
|
||||
input: "greater"
|
||||
output: "cast_to_int"
|
||||
name: "cast_to_int"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
type: INT
|
||||
i: 6
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "cast_to_int"
|
||||
output: "reduce_max"
|
||||
name: "reduce_max"
|
||||
op_type: "ReduceMax"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
type: INT
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "reduce_max"
|
||||
output: "cast_to_bool"
|
||||
name: "cast_to_bool"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
type: INT
|
||||
i: 9
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "cast_to_bool"
|
||||
output: "if_inside"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "then_branch_inside"
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "mul"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
output {
|
||||
name: "mul"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "else_branch_inside"
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if_inside"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "else_branch"
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "sub"
|
||||
name: "sub"
|
||||
op_type: "Sub"
|
||||
}
|
||||
output {
|
||||
name: "sub"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
209
ngraph/test/models/onnx/controlflow/if_inside_loop.prototxt
Normal file
209
ngraph/test/models/onnx/controlflow/if_inside_loop.prototxt
Normal file
@@ -0,0 +1,209 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if inside loop"
|
||||
node {
|
||||
input: "scale"
|
||||
input: "scale"
|
||||
name: "mul_node"
|
||||
op_type: "Mul"
|
||||
output: "b"
|
||||
}
|
||||
node {
|
||||
input: "trip_count"
|
||||
input: ""
|
||||
input: "a_init"
|
||||
output: "a_final"
|
||||
output: "a_values"
|
||||
op_type: "Loop"
|
||||
attribute {
|
||||
name: "body"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "loop body"
|
||||
node {
|
||||
output: "zero"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
type: TENSOR
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "i"
|
||||
input: "zero"
|
||||
output: "first_iter"
|
||||
name: "equal"
|
||||
op_type: "Equal"
|
||||
}
|
||||
node {
|
||||
input: "first_iter"
|
||||
output: "current_a"
|
||||
name: "current_a"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "then_branch"
|
||||
node {
|
||||
input: "b"
|
||||
input: "a_in"
|
||||
output: "a_out"
|
||||
name: "loop_body_add"
|
||||
op_type: "Add"
|
||||
}
|
||||
output {
|
||||
name: "a_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
type: GRAPH
|
||||
g {
|
||||
name: "else_branch"
|
||||
node {
|
||||
input: "b"
|
||||
input: "a_in"
|
||||
output: "a_out"
|
||||
name: "loop_body_mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
output {
|
||||
name: "a_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "cond"
|
||||
output: "cond_out"
|
||||
name: "cond_identity"
|
||||
op_type: "Identity"
|
||||
}
|
||||
node {
|
||||
input: "current_a"
|
||||
output: "a_out"
|
||||
name: "output_accumulator"
|
||||
op_type: "Identity"
|
||||
}
|
||||
input {
|
||||
name: "i"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "cond"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "a_in"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cond_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "current_a"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
name: "trip_count"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 1
|
||||
float_data: 2
|
||||
name: "scale"
|
||||
}
|
||||
input {
|
||||
name: "a_init"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_final"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "then_branch"
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "mul"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
name: "else_branch"
|
||||
output {
|
||||
name: "mul"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "if graph"
|
||||
node {
|
||||
input: "condition"
|
||||
output: "if"
|
||||
name: "if"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
name: "then_branch"
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
name: "else_branch"
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "mul"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "add"
|
||||
name: "add"
|
||||
op_type: "Add"
|
||||
}
|
||||
output {
|
||||
name: "add"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "mul"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "condition"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "if"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 10
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -31,7 +31,7 @@ using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||
// }
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add.onnx"));
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add.onnx"));
|
||||
|
||||
// Shape inference tests
|
||||
const auto& parameters = function->get_parameters();
|
||||
@@ -65,7 +65,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add) {
|
||||
// }
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_cond) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_no_identity_termination_cond.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_no_identity_termination_cond.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// termination condition
|
||||
@@ -80,7 +80,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_co
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_max_int) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_trip_count_max_int.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_trip_count_max_int.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// termination condition
|
||||
@@ -95,7 +95,8 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_max_int) {
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_cond_static_shapes) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_no_identity_termination_cond_static_shapes.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/controlflow/loop_2d_add_no_identity_termination_cond_static_shapes.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// termination condition
|
||||
@@ -111,7 +112,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_co
|
||||
// input ("", cond) // Note this is analogous to a while loop
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_cond_false) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_no_identity_termination_cond_false.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_no_identity_termination_cond_false.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -129,7 +130,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_co
|
||||
// }
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_const_no_identity_termination_cond) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_const_no_identity_termination_cond.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_const_no_identity_termination_cond.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// a_init
|
||||
@@ -143,7 +144,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_const_no_identity_terminat
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_const_no_identity_termination_cond_static_shapes) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/loop/loop_2d_add_const_no_identity_termination_cond_static_shapes.onnx"));
|
||||
"onnx/controlflow/loop_2d_add_const_no_identity_termination_cond_static_shapes.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -161,7 +162,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_const_no_identity_terminat
|
||||
// }
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_both_cond_and_trip_count_as_inputs) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_cond_and_trip_count_as_inputs.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_cond_and_trip_count_as_inputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// trip count
|
||||
@@ -180,7 +181,8 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_both_cond_and_trip_count_a
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_both_cond_and_trip_count_as_inputs_static_shapes) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_cond_and_trip_count_as_inputs_static_shapes.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/controlflow/loop_2d_add_cond_and_trip_count_as_inputs_static_shapes.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// trip count
|
||||
@@ -200,7 +202,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_both_cond_and_trip_count_a
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_initializer_from_parent_scope) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_initializer_from_parent_scope.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_initializer_from_parent_scope.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
@@ -214,7 +216,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_initializer_from_parent_s
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_node_from_parent_scope.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_node_from_parent_scope.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -228,7 +230,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope) {
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope_used_in_parent_and_in_body) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.onnx"));
|
||||
"onnx/controlflow/loop_add_node_from_parent_scope_used_in_parent_and_in_body.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -245,7 +247,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope_us
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_value_access_to_body_scope_exception) {
|
||||
try {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_incorrect_access_body_scope.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_incorrect_access_body_scope.onnx"));
|
||||
FAIL() << "Incorrect access to body scope not detected";
|
||||
} catch (const ngraph_error& e) {
|
||||
// patent graph should have no access to subgraph (body Loop) scope
|
||||
@@ -257,7 +259,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_value_access_to_body_scop
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_value_the_same_node_from_parent_and_subgraph) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_the_same_name.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_the_same_name.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -270,7 +272,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_value_the_same_node_from_
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_input_from_parent_graph) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_input_from_parent_graph.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_input_from_parent_graph.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -284,8 +286,8 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_input_from_parent_graph)
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_the_proper_opset_in_subgraph) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_mul_opset1.onnx"));
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_mul_opset1.onnx"));
|
||||
|
||||
const auto parent_ops = function->get_ops();
|
||||
const auto loop_node_it = std::find_if(parent_ops.begin(), parent_ops.end(), [](const std::shared_ptr<Node>& op) {
|
||||
@@ -303,7 +305,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_the_proper_opset_in_subgraph)
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_scalars) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_scalars_add.onnx"));
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_scalars_add.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -315,8 +317,8 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_scalars) {
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add_const_cond) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_const_cond.onnx"));
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_const_cond.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -329,7 +331,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add_const_cond) {
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_dynamic) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_trip_count_dynamic.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_trip_count_dynamic.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// trip count
|
||||
@@ -345,7 +347,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_dynamic) {
|
||||
// ~~~~~~~~SUBGRAPH TYPES INFERENCE:~~~~~~~~
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_infer_types) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/onnx_controlflow_loop_2d_infer_types.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/onnx_controlflow_loop_2d_infer_types.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// trip count
|
||||
@@ -363,7 +365,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_infer_types) {
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope_infer_types) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_add_node_from_parent_scope_infer_types.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_add_node_from_parent_scope_infer_types.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
@@ -380,8 +382,8 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope_in
|
||||
// ~~~~~~~~ADDITIONAL TESTS:~~~~~~~~
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_concat_values) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_concat_values.onnx"));
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_concat_values.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// trip_count
|
||||
@@ -405,7 +407,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_concat_values) {
|
||||
// }
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_and_cond_skipped_shape_inference) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_trip_count_and_cond_skipped.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_2d_add_trip_count_and_cond_skipped.onnx"));
|
||||
|
||||
const auto& results = function->get_results();
|
||||
EXPECT_EQ(results.size(), 2);
|
||||
@@ -421,7 +423,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_and_cond_skippe
|
||||
// infinitive loop execution
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_infinite) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_infinite.onnx"));
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_infinite.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// trip_count
|
||||
@@ -442,7 +444,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_infinite) {
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_no_variadic_inputs_and_outputs) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_no_variadic_inputs_and_outputs.onnx"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_no_variadic_inputs_and_outputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// trip_count
|
||||
@@ -457,7 +459,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_no_variadic_inputs_and_output
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_power) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_pow.onnx"));
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/loop_pow.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// trip_count
|
||||
@@ -471,3 +473,291 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_power) {
|
||||
test_case.add_expected_output<int64_t>(Shape{5}, {0, 1, 4, 9, 16});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_branches_with_same_inputs) {
|
||||
/*
|
||||
if (condition) {
|
||||
add(x, y)
|
||||
} else {
|
||||
mul(x, y)
|
||||
}
|
||||
*/
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_branches_with_same_inputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
std::vector<float> x(40, 2);
|
||||
std::vector<float> y(40);
|
||||
std::iota(y.begin(), y.end(), -20);
|
||||
|
||||
// condition
|
||||
test_case.add_input<bool>({true});
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
|
||||
std::vector<float> expected;
|
||||
std::transform(x.begin(), x.end(), y.begin(), std::back_inserter(expected), [](float i, float j) -> float {
|
||||
return i + j;
|
||||
});
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
|
||||
std::transform(x.begin(), x.end(), y.begin(), expected.begin(), [](float i, float j) -> float {
|
||||
return i * j;
|
||||
});
|
||||
test_case.add_input<bool>({false});
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_branches_with_different_inputs) {
|
||||
/*
|
||||
if (condition) {
|
||||
add(x, y)
|
||||
} else {
|
||||
abs(y)
|
||||
}
|
||||
*/
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_branches_with_different_inputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
std::vector<float> x(40, 2);
|
||||
std::vector<float> y(40);
|
||||
std::iota(y.begin(), y.end(), -20);
|
||||
|
||||
// condition
|
||||
test_case.add_input<bool>({true});
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
|
||||
std::vector<float> expected;
|
||||
std::transform(x.begin(), x.end(), y.begin(), std::back_inserter(expected), [](float i, float j) -> float {
|
||||
return i + j;
|
||||
});
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
|
||||
std::transform(y.begin(), y.end(), expected.begin(), [](float i) -> float {
|
||||
return std::fabs(i);
|
||||
});
|
||||
test_case.add_input<bool>({false});
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_branches_without_inputs) {
|
||||
/*
|
||||
if (condition) {
|
||||
return const tensor {0, 1, 2, 3, 4, 5, 6, 7}
|
||||
} else {
|
||||
return const tensor {0, 5, 10, 15, 20, 25, 20}
|
||||
}
|
||||
*/
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_branches_without_inputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
// condition
|
||||
test_case.add_input<bool>({true});
|
||||
|
||||
test_case.add_expected_output<float>({0, 1, 2, 3, 4, 5, 6, 7});
|
||||
test_case.run();
|
||||
|
||||
test_case.add_input<bool>({false});
|
||||
test_case.add_expected_output<float>({0, 5, 10, 15, 20, 25, 20, 15});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_inside_if) {
|
||||
/*
|
||||
if (condition) {
|
||||
if (any(x > y) {
|
||||
mul(x, y)
|
||||
} else {
|
||||
add(x, y)
|
||||
}
|
||||
} else {
|
||||
sub(x, y)
|
||||
}
|
||||
*/
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_inside_if.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
// case when condition == true and any(x > y)
|
||||
// expected value == x * y
|
||||
std::vector<float> x(40, 2);
|
||||
std::vector<float> y(40);
|
||||
std::iota(y.begin(), y.end(), -20);
|
||||
std::vector<float> expected;
|
||||
std::transform(x.begin(), x.end(), y.begin(), std::back_inserter(expected), [](float i, float j) -> float {
|
||||
return i * j;
|
||||
});
|
||||
test_case.add_input<bool>({true}); // condition
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
|
||||
// case when condition == true and all(x < y)
|
||||
// expected value == x + y
|
||||
std::iota(x.begin(), x.end(), -static_cast<float>(x.size()));
|
||||
std::iota(y.begin(), y.end(), 1);
|
||||
std::transform(x.begin(), x.end(), y.begin(), expected.begin(), [](float i, float j) -> float {
|
||||
return i + j;
|
||||
});
|
||||
test_case.add_input<bool>({true}); // condition
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
|
||||
// case when condition == false
|
||||
// expected value == x - y
|
||||
std::transform(x.begin(), x.end(), y.begin(), expected.begin(), [](float i, float j) -> float {
|
||||
return i - j;
|
||||
});
|
||||
test_case.add_input<bool>({false});
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_input<float>(y);
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_branches_with_multiple_outputs) {
|
||||
/*
|
||||
if (condition) {
|
||||
split(x, axis=0)
|
||||
} else {
|
||||
part1, part2, part3 = split(x, axis=1)
|
||||
transpose(part1), transpose(part2), transpose(part3)
|
||||
}
|
||||
*/
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_branches_with_multiple_outputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
// case when condition == true so split is along axis 0
|
||||
std::vector<float> x(36);
|
||||
std::iota(x.begin(), x.end(), 0);
|
||||
std::vector<float> expected1(12);
|
||||
std::iota(expected1.begin(), expected1.end(), 0);
|
||||
std::vector<float> expected2(12);
|
||||
std::iota(expected2.begin(), expected2.end(), 12);
|
||||
std::vector<float> expected3(12);
|
||||
std::iota(expected3.begin(), expected3.end(), 24);
|
||||
test_case.add_input<bool>({true}); // condition
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_expected_output<float>(expected1);
|
||||
test_case.add_expected_output<float>(expected2);
|
||||
test_case.add_expected_output<float>(expected3);
|
||||
test_case.run();
|
||||
|
||||
// case when condition == false so split is along axis 1
|
||||
test_case.add_input<bool>({false}); // condition
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_expected_output<float>({0, 6, 12, 18, 24, 30, 1, 7, 13, 19, 25, 31});
|
||||
test_case.add_expected_output<float>({2, 8, 14, 20, 26, 32, 3, 9, 15, 21, 27, 33});
|
||||
test_case.add_expected_output<float>({4, 10, 16, 22, 28, 34, 5, 11, 17, 23, 29, 35});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_inside_loop) {
|
||||
/*
|
||||
for (i = 0; i < 3; i++) {
|
||||
if (i == 0)
|
||||
a = a + b
|
||||
else
|
||||
a = a * b
|
||||
}
|
||||
*/
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_inside_loop.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {64.f, 64.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {4.f, 4.f, 16.f, 16.f, 64.f, 64.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_dynamic_inputs) {
|
||||
/*
|
||||
if (condition) {
|
||||
add(x, y)
|
||||
} else {
|
||||
mul(x, y)
|
||||
}
|
||||
*/
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_dynamic_inputs.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
std::vector<float> x(40, 2);
|
||||
std::vector<float> y(40);
|
||||
std::iota(y.begin(), y.end(), -20);
|
||||
std::vector<float> expected;
|
||||
std::transform(x.begin(), x.end(), y.begin(), std::back_inserter(expected), [](float i, float j) -> float {
|
||||
return i + j;
|
||||
});
|
||||
|
||||
test_case.add_input<bool>(Shape{}, {true}); // condition
|
||||
test_case.add_input<float>(Shape{4, 10}, x);
|
||||
test_case.add_input<float>(Shape{4, 10}, y);
|
||||
test_case.add_expected_output<float>(Shape{4, 10}, expected);
|
||||
test_case.run();
|
||||
|
||||
std::transform(x.begin(), x.end(), y.begin(), expected.begin(), [](float i, float j) -> float {
|
||||
return i * j;
|
||||
});
|
||||
test_case.add_input<bool>(Shape{}, {false});
|
||||
test_case.add_input<float>(Shape{4, 10}, x);
|
||||
test_case.add_input<float>(Shape{4, 10}, y);
|
||||
test_case.add_expected_output<float>(Shape{4, 10}, expected);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_negative_missing_branches) {
|
||||
try {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_missing_then_branch.onnx"));
|
||||
FAIL() << "Model import succeed, but it shouldn't";
|
||||
} catch (const ngraph_error& e) {
|
||||
EXPECT_HAS_SUBSTRING(e.what(), std::string("Missing 'then_branch' attribute"));
|
||||
} catch (...) {
|
||||
FAIL() << "Model import failed for unexpected reason";
|
||||
}
|
||||
|
||||
try {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_missing_else_branch.onnx"));
|
||||
FAIL() << "Model import succeed, but it shouldn't";
|
||||
} catch (const ngraph_error& e) {
|
||||
EXPECT_HAS_SUBSTRING(e.what(), std::string("Missing 'else_branch' attribute"));
|
||||
} catch (...) {
|
||||
FAIL() << "Model import failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_negative_mismatch_between_branches_output) {
|
||||
try {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_negative_mismatch_between_branches_output.onnx"));
|
||||
FAIL() << "Model import succeed, but it shouldn't";
|
||||
} catch (const ngraph_error& e) {
|
||||
EXPECT_HAS_SUBSTRING(e.what(),
|
||||
std::string("'then' and 'else' branches have to have the same number of outputs"));
|
||||
} catch (...) {
|
||||
FAIL() << "Model import failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1441,6 +1441,7 @@ onnx_controlflow_loop_2d_const_no_identity_termination_cond
|
||||
onnx_controlflow_loop_2d_both_cond_and_trip_count_as_inputs
|
||||
onnx_controlflow_loop_no_variadic_inputs_and_outputs
|
||||
onnx_controlflow_loop_power
|
||||
onnx_if_dynamic_inputs
|
||||
|
||||
# Input body shape is changed during Loop iterations
|
||||
# Exception is throw during Loop shape inference
|
||||
|
||||
@@ -63,8 +63,6 @@ xfail_issue_38708 = xfail_test(reason="RuntimeError: While validating ONNX node
|
||||
xfail_issue_38710 = xfail_test(reason="RuntimeError: data has zero dimension which is not allowed")
|
||||
xfail_issue_38713 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"ai.onnx.preview.training.Momentum")
|
||||
xfail_issue_43742 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"If")
|
||||
xfail_issue_45457 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v5::Loop"
|
||||
"Not constant termination condition body output is not supported")
|
||||
xfail_issue_38722 = xfail_test(reason="RuntimeError: While validating ONNX nodes MatMulInteger"
|
||||
@@ -151,3 +149,4 @@ skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with
|
||||
xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike")
|
||||
xfail_issue_63137 = xfail_test(reason="Unsupported operations: OptionalHasElement, OptionalGetElement")
|
||||
xfail_issue_63138 = xfail_test(reason="Missing ONNX Shape-15 support")
|
||||
xfail_issue_63643 = xfail_test(reason="RuntimeError: Unsupported operation of type: Convolution name")
|
||||
|
||||
@@ -33,7 +33,6 @@ from tests import (
|
||||
xfail_issue_39658,
|
||||
xfail_issue_39659,
|
||||
xfail_issue_39662,
|
||||
xfail_issue_43742,
|
||||
xfail_issue_44848,
|
||||
xfail_issue_44851,
|
||||
xfail_issue_44854,
|
||||
@@ -221,6 +220,7 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendSimpleModelTest.test_sequence_model4_cpu",
|
||||
"OnnxBackendSimpleModelTest.test_sequence_model2_cpu",
|
||||
"OnnxBackendNodeModelTest.test_identity_sequence_cpu",
|
||||
"OnnxBackendNodeModelTest.test_if_seq_cpu",
|
||||
),
|
||||
(
|
||||
xfail_issue_38701,
|
||||
@@ -376,11 +376,6 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_reduce_sum_keepdims_random_cpu",
|
||||
"OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_example_cpu",
|
||||
),
|
||||
(
|
||||
xfail_issue_43742,
|
||||
"OnnxBackendNodeModelTest.test_if_cpu",
|
||||
"OnnxBackendNodeModelTest.test_if_seq_cpu",
|
||||
),
|
||||
(
|
||||
xfail_issue_44848,
|
||||
"OnnxBackendNodeModelTest.test_range_float_type_positive_delta_cpu",
|
||||
|
||||
@@ -12,7 +12,6 @@ from tests.test_onnx.utils.model_importer import ModelImportRunner
|
||||
|
||||
from tests import (
|
||||
xfail_issue_38701,
|
||||
xfail_issue_43742,
|
||||
xfail_issue_45457,
|
||||
xfail_issue_37957,
|
||||
xfail_issue_38084,
|
||||
@@ -24,6 +23,7 @@ from tests import (
|
||||
xfail_issue_48145,
|
||||
xfail_issue_48190,
|
||||
xfail_issue_58676,
|
||||
xfail_issue_63643,
|
||||
xfail_issue_onnx_models_140)
|
||||
|
||||
MODELS_ROOT_DIR = tests.MODEL_ZOO_DIR
|
||||
@@ -145,11 +145,9 @@ if len(zoo_models) > 0:
|
||||
import_xfail_list = [
|
||||
# ONNX Model Zoo
|
||||
(xfail_issue_38701, "test_onnx_model_zoo_text_machine_comprehension_bidirectional_attention_flow_model_bidaf_9_bidaf_bidaf_cpu"),
|
||||
(xfail_issue_43742, "test_onnx_model_zoo_vision_object_detection_segmentation_ssd_mobilenetv1_model_ssd_mobilenet_v1_10_ssd_mobilenet_v1_ssd_mobilenet_v1_cpu"),
|
||||
(xfail_issue_38726, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_decoder_with_lm_head_12_t5_decoder_with_lm_head_cpu"),
|
||||
|
||||
# Model MSFT
|
||||
(xfail_issue_43742, "test_MSFT_opset10_mlperf_ssd_mobilenet_300_ssd_mobilenet_v1_coco_2018_01_28_cpu"),
|
||||
(xfail_issue_37957, "test_MSFT_opset10_mask_rcnn_keras_mask_rcnn_keras_cpu"),
|
||||
]
|
||||
for test_case in import_xfail_list:
|
||||
@@ -170,6 +168,7 @@ if len(zoo_models) > 0:
|
||||
(xfail_issue_48145, "test_onnx_model_zoo_text_machine_comprehension_bert_squad_model_bertsquad_8_download_sample_8_bertsquad8_cpu"),
|
||||
(xfail_issue_48190, "test_onnx_model_zoo_text_machine_comprehension_roberta_model_roberta_base_11_roberta_base_11_roberta_base_11_cpu"),
|
||||
(xfail_issue_onnx_models_140, "test_onnx_model_zoo_vision_object_detection_segmentation_duc_model_ResNet101_DUC_7_ResNet101_DUC_HDC_ResNet101_DUC_HDC_cpu"),
|
||||
(xfail_issue_63643, "test_onnx_model_zoo_vision_object_detection_segmentation_ssd_mobilenetv1_model_ssd_mobilenet_v1_10_ssd_mobilenet_v1_ssd_mobilenet_v1_cpu"),
|
||||
|
||||
# Model MSFT
|
||||
(xfail_issue_37973, "test_MSFT_opset7_tf_inception_v2_model_cpu"),
|
||||
@@ -187,7 +186,7 @@ if len(zoo_models) > 0:
|
||||
(xfail_issue_39669, "test_MSFT_opset9_cgan_cgan_cpu"),
|
||||
(xfail_issue_47495, "test_MSFT_opset10_BERT_Squad_bertsquad10_cpu"),
|
||||
(xfail_issue_45457, "test_MSFT_opset10_mlperf_ssd_resnet34_1200_ssd_resnet34_mAP_20.2_cpu"),
|
||||
|
||||
(xfail_issue_63643, "test_MSFT_opset10_mlperf_ssd_mobilenet_300_ssd_mobilenet_v1_coco_2018_01_28_cpu"),
|
||||
]
|
||||
for test_case in import_xfail_list + execution_xfail_list:
|
||||
xfail, test_name = test_case
|
||||
|
||||
Reference in New Issue
Block a user