[ONNX] Split importing model to two phases: decode and convert (#6326)
This commit is contained in:
parent
8fb1182670
commit
1a5bd8a510
@ -134,6 +134,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
|
||||
return Precision(Precision::BIN);
|
||||
case ::ngraph::element::Type_t::boolean:
|
||||
return Precision(Precision::BOOL);
|
||||
case ::ngraph::element::Type_t::dynamic:
|
||||
return Precision(Precision::UNSPECIFIED);
|
||||
default:
|
||||
IE_THROW() << "Incorrect precision " << precision.get_type_name() << "!"; return{};
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ class TRANSFORMATIONS_API FrameworkNode : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
explicit FrameworkNode(const OutputVector& inputs);
|
||||
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
|
@ -10,8 +10,9 @@ using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::FrameworkNode, "FrameworkNode", 0);
|
||||
|
||||
op::FrameworkNode::FrameworkNode(const OutputVector& inputs)
|
||||
op::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size)
|
||||
: Op(inputs) {
|
||||
set_output_size(output_size);
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -495,6 +495,7 @@ std::string get_opset_name(
|
||||
std::string get_precision_name(const ngraph::element::Type & elem_type) {
|
||||
switch (elem_type) {
|
||||
case ::ngraph::element::Type_t::undefined:
|
||||
case ::ngraph::element::Type_t::dynamic:
|
||||
return "UNSPECIFIED";
|
||||
case ::ngraph::element::Type_t::f16:
|
||||
return "FP16";
|
||||
|
@ -45,7 +45,7 @@ if(COMMAND ie_faster_build)
|
||||
)
|
||||
endif()
|
||||
|
||||
target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder
|
||||
target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder inference_engine_transformations
|
||||
PUBLIC ngraph)
|
||||
|
||||
target_include_directories(onnx_importer PUBLIC $<BUILD_INTERFACE:${ONNX_IMPORT_INCLUDE_DIR}>
|
||||
|
@ -75,9 +75,8 @@ namespace ngraph
|
||||
|
||||
bool has_attribute(const std::string& name) const;
|
||||
|
||||
Subgraph get_subgraph_from_attribute(
|
||||
const std::string& name,
|
||||
const std::map<std::size_t, std::string>& carried_dependencies_map) const;
|
||||
bool has_subgraph() const;
|
||||
std::shared_ptr<Subgraph> get_subgraph() const;
|
||||
|
||||
template <typename T>
|
||||
T get_attribute_value(const std::string& name, T default_value) const;
|
||||
|
@ -0,0 +1,100 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <core/graph.hpp>
|
||||
#include <ngraph/visibility.hpp>
|
||||
#include <ngraph_ops/framework_node.hpp>
|
||||
#include <onnx_import/core/node.hpp>
|
||||
|
||||
namespace ONNX_NAMESPACE
|
||||
{
|
||||
// forward declaration
|
||||
class ModelProto;
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
class Model;
|
||||
}
|
||||
|
||||
namespace frontend
|
||||
{
|
||||
class ONNXFrameworkNode : public op::FrameworkNode
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
ONNXFrameworkNode(const onnx_import::Node& node)
|
||||
: FrameworkNode(node.get_ng_inputs(), node.get_outputs_size())
|
||||
, m_node(node)
|
||||
{
|
||||
}
|
||||
|
||||
ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
|
||||
: FrameworkNode(inputs, node.get_outputs_size())
|
||||
, m_node(node)
|
||||
{
|
||||
}
|
||||
|
||||
const onnx_import::Node& get_onnx_node() const { return m_node; }
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& inputs) const override;
|
||||
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor) override
|
||||
{
|
||||
// TODO: implement reading as well, now it work for serialization only
|
||||
std::string domain = m_node.domain();
|
||||
std::string op_type = m_node.op_type();
|
||||
visitor.on_attribute("ONNX_META_domain", domain);
|
||||
visitor.on_attribute("ONNX_META_type", op_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
onnx_import::Node m_node;
|
||||
};
|
||||
|
||||
class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
ONNXSubgraphFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
|
||||
: ONNXFrameworkNode(node, inputs)
|
||||
{
|
||||
}
|
||||
|
||||
void infer_inputs_from_parent()
|
||||
{
|
||||
get_onnx_node().get_subgraph()->infer_inputs_from_parent();
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_subgraph_body() const
|
||||
{
|
||||
auto subgraph = get_onnx_node().get_subgraph();
|
||||
return std::make_shared<Function>(subgraph->get_ng_outputs(),
|
||||
subgraph->get_ng_parameters(),
|
||||
subgraph->get_name());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
@ -11,9 +11,7 @@ namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
Subgraph Attribute::get_subgraph(
|
||||
const Graph& parent_graph,
|
||||
const std::map<std::size_t, std::string>& carried_dependencies_map) const
|
||||
Subgraph Attribute::get_subgraph(const Graph& parent_graph) const
|
||||
{
|
||||
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH)
|
||||
{
|
||||
@ -25,33 +23,6 @@ namespace ngraph
|
||||
const auto& graph = m_attribute_proto->g();
|
||||
model_proto->mutable_graph()->CopyFrom(graph);
|
||||
|
||||
const std::size_t subgraph_inputs_count =
|
||||
static_cast<size_t>(model_proto->mutable_graph()->mutable_input()->size());
|
||||
// Use the `carried_dependencies_map` to infer the types for the subgraph inputs
|
||||
for (const auto& carried_dependency : carried_dependencies_map)
|
||||
{
|
||||
if (carried_dependency.first >= subgraph_inputs_count)
|
||||
{
|
||||
NGRAPH_WARN << "Input with index: '" << carried_dependency.first
|
||||
<< "' was not found in the subgraph";
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& parent_in =
|
||||
parent_graph.get_ng_node_from_cache(carried_dependency.second);
|
||||
const auto& carried_type = parent_in.get_element_type();
|
||||
auto subgraph_in =
|
||||
model_proto->mutable_graph()->mutable_input(carried_dependency.first);
|
||||
auto subgraph_in_tensor_type =
|
||||
subgraph_in->mutable_type()->mutable_tensor_type();
|
||||
if (!subgraph_in_tensor_type->has_elem_type())
|
||||
{
|
||||
subgraph_in_tensor_type->set_elem_type(
|
||||
onnx_common::ng_to_onnx_data_type(carried_type));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set opset version and domain from the parent graph
|
||||
model_proto->mutable_opset_import()->CopyFrom(parent_graph.get_opset_imports());
|
||||
auto model = common::make_unique<Model>(std::move(model_proto));
|
||||
|
@ -316,9 +316,7 @@ namespace ngraph
|
||||
float get_float() const { return m_attribute_proto->f(); }
|
||||
int64_t get_integer() const { return m_attribute_proto->i(); }
|
||||
const std::string& get_string() const { return m_attribute_proto->s(); }
|
||||
Subgraph get_subgraph(
|
||||
const Graph& parent_graph,
|
||||
const std::map<std::size_t, std::string>& carried_dependencies_map) const;
|
||||
Subgraph get_subgraph(const Graph& parent_graph) const;
|
||||
|
||||
std::vector<Tensor> get_tensor_array() const
|
||||
{
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
#include "onnx_import/onnx_framework_node.hpp"
|
||||
#include "utils/common.hpp"
|
||||
#include "utils/provenance_tag.hpp"
|
||||
|
||||
@ -55,25 +56,6 @@ namespace ngraph
|
||||
Graph::Graph(std::unique_ptr<Model>&& model)
|
||||
: Graph(std::move(model), common::make_unique<GraphCache>())
|
||||
{
|
||||
// Remove dangling Parameters
|
||||
for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();)
|
||||
{
|
||||
if ((*param_it)->get_output_target_inputs(0).size() == 0)
|
||||
{
|
||||
const auto& name = (*param_it)->get_friendly_name();
|
||||
auto out_it = std::find_if(
|
||||
m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) {
|
||||
return info.get_name() == name;
|
||||
});
|
||||
if (out_it == m_outputs.end())
|
||||
{
|
||||
m_cache->remove_node(name);
|
||||
param_it = m_parameters.erase(param_it);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
param_it++;
|
||||
}
|
||||
}
|
||||
|
||||
Graph::Graph(std::unique_ptr<Model>&& model, std::unique_ptr<GraphCache>&& cache)
|
||||
@ -174,14 +156,82 @@ namespace ngraph
|
||||
NGRAPH_CHECK(unknown_operators.empty(),
|
||||
"nGraph does not support the following ONNX operations: ",
|
||||
detail::to_string(unknown_operators));
|
||||
}
|
||||
|
||||
void Graph::convert_to_ngraph_nodes()
|
||||
{
|
||||
// Process ONNX graph nodes, convert to nGraph nodes
|
||||
for (const auto& node_proto : m_model->get_graph().node())
|
||||
{
|
||||
m_nodes.emplace_back(node_proto, *this);
|
||||
const Node& node{m_nodes.back()};
|
||||
|
||||
if (node.has_subgraph())
|
||||
{
|
||||
auto subgraph = node.get_subgraph();
|
||||
auto body_func = subgraph->convert();
|
||||
}
|
||||
OutputVector ng_nodes{node.get_ng_nodes()};
|
||||
set_friendly_names(node, ng_nodes);
|
||||
for (std::size_t i{0}; i < node.get_outputs_size(); ++i)
|
||||
{
|
||||
m_cache->emplace_node(node.output(i), std::move(ng_nodes.at(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Graph::remove_dangling_parameters()
|
||||
{
|
||||
for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();)
|
||||
{
|
||||
if ((*param_it)->get_output_target_inputs(0).size() == 0)
|
||||
{
|
||||
const auto& name = (*param_it)->get_friendly_name();
|
||||
auto out_it = std::find_if(
|
||||
m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) {
|
||||
return info.get_name() == name;
|
||||
});
|
||||
if (out_it == m_outputs.end())
|
||||
{
|
||||
m_cache->remove_node(name);
|
||||
param_it = m_parameters.erase(param_it);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
param_it++;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> Graph::convert()
|
||||
{
|
||||
convert_to_ngraph_nodes();
|
||||
remove_dangling_parameters();
|
||||
return create_function();
|
||||
}
|
||||
|
||||
void Graph::decode_to_framework_nodes()
|
||||
{
|
||||
// Process ONNX graph nodes, convert to nGraph nodes
|
||||
for (const auto& node_proto : m_model->get_graph().node())
|
||||
{
|
||||
m_nodes.emplace_back(node_proto, *this);
|
||||
const Node& node{m_nodes.back()};
|
||||
std::shared_ptr<frontend::ONNXFrameworkNode> framework_node;
|
||||
if (node.has_subgraph())
|
||||
{
|
||||
auto subgraph = node.get_subgraph();
|
||||
auto body_func = subgraph->decode();
|
||||
auto inputs = node.get_ng_inputs();
|
||||
for (const auto& input : subgraph->get_inputs_from_parent())
|
||||
inputs.push_back(input);
|
||||
framework_node =
|
||||
std::make_shared<ngraph::frontend::ONNXSubgraphFrameworkNode>(node, inputs);
|
||||
}
|
||||
else
|
||||
{
|
||||
framework_node = std::make_shared<ngraph::frontend::ONNXFrameworkNode>(node);
|
||||
}
|
||||
OutputVector ng_nodes{framework_node->outputs()};
|
||||
set_friendly_names(node, ng_nodes);
|
||||
// Iterate over the number of outputs for given node in graph.
|
||||
// Some of them may be optional and trimmed. See:
|
||||
// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
|
||||
@ -192,12 +242,24 @@ namespace ngraph
|
||||
}
|
||||
}
|
||||
|
||||
const GraphCache& Graph::get_graph_cache() const { return *m_cache.get(); }
|
||||
bool Graph::is_node_in_cache(const std::string& name) const
|
||||
std::shared_ptr<Function> Graph::create_function()
|
||||
{
|
||||
return m_cache->contains(name);
|
||||
auto function = std::make_shared<Function>(get_ng_outputs(), m_parameters, get_name());
|
||||
for (std::size_t i{0}; i < function->get_output_size(); ++i)
|
||||
{
|
||||
function->get_output_op(i)->set_friendly_name(m_outputs.at(i).get_name());
|
||||
}
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> Graph::decode()
|
||||
{
|
||||
decode_to_framework_nodes();
|
||||
return create_function();
|
||||
}
|
||||
|
||||
const GraphCache& Graph::get_graph_cache() const { return *m_cache.get(); }
|
||||
|
||||
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
|
||||
{
|
||||
return m_cache->get_node(name);
|
||||
@ -247,6 +309,12 @@ namespace ngraph
|
||||
set_friendly_names(onnx_node, ng_node_vector);
|
||||
add_provenance_tags(onnx_node, ng_node_vector);
|
||||
|
||||
for (std::size_t i{0}; i < onnx_node.get_outputs_size(); ++i)
|
||||
{
|
||||
auto ng_node = ng_node_vector.at(i);
|
||||
m_cache->emplace_node(onnx_node.output(i), std::move(ng_node));
|
||||
}
|
||||
|
||||
return ng_node_vector;
|
||||
}
|
||||
|
||||
@ -323,9 +391,21 @@ namespace ngraph
|
||||
}
|
||||
|
||||
Subgraph::Subgraph(std::unique_ptr<Model>&& model, const Graph& parent_graph)
|
||||
: Graph(
|
||||
std::move(model),
|
||||
std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
|
||||
: Graph(std::move(model), common::make_unique<GraphCache>())
|
||||
, m_parent_graph_cache(&parent_graph.get_graph_cache())
|
||||
{
|
||||
}
|
||||
|
||||
Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) const
|
||||
{
|
||||
if (m_cache->contains(name))
|
||||
{
|
||||
return m_cache->get_node(name);
|
||||
}
|
||||
return m_parent_graph_cache->get_node(name);
|
||||
}
|
||||
|
||||
void Subgraph::find_inputs_from_parent()
|
||||
{
|
||||
// find all nodes on edge parent graph-subgraph
|
||||
// (it means input of node from parent graph, output from subgraph)
|
||||
@ -334,16 +414,16 @@ namespace ngraph
|
||||
int input_index = 0;
|
||||
for (const auto& in_name : node_proto.input())
|
||||
{
|
||||
if (m_cache->node_scope(in_name) == NodeScope::ParentGraph)
|
||||
if (m_parent_graph_cache->contains(in_name))
|
||||
{
|
||||
const auto& from_parent_node = m_cache->get_node(in_name);
|
||||
const auto& from_parent_node = m_parent_graph_cache->get_node(in_name);
|
||||
// constants are skipped
|
||||
if (!ngraph::is_type<ngraph::op::Constant>(
|
||||
from_parent_node.get_node_shared_ptr()))
|
||||
{
|
||||
for (const auto& out_name : node_proto.output())
|
||||
{
|
||||
if (m_cache->node_scope(out_name) == NodeScope::SubGraph)
|
||||
if (m_cache->contains(out_name))
|
||||
{
|
||||
auto out_node_to_replace_input = m_cache->get_node(out_name);
|
||||
auto new_param = std::make_shared<ngraph::op::Parameter>(
|
||||
@ -353,8 +433,10 @@ namespace ngraph
|
||||
out_node_to_replace_input.get_node()
|
||||
->input(input_index)
|
||||
.replace_source_output(new_param);
|
||||
m_parameter_to_parent_node_map.insert({new_param, in_name});
|
||||
m_cache->emplace_node(in_name, new_param);
|
||||
m_parameters.push_back(new_param);
|
||||
m_outputs_from_parent.push_back(from_parent_node);
|
||||
m_inputs_from_parent.push_back(in_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -364,11 +446,39 @@ namespace ngraph
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<Output<ngraph::Node>> Subgraph::get_outputs_from_parent() const
|
||||
std::shared_ptr<Function> Subgraph::convert()
|
||||
{
|
||||
return m_outputs_from_parent;
|
||||
convert_to_ngraph_nodes();
|
||||
find_inputs_from_parent();
|
||||
return create_function();
|
||||
}
|
||||
|
||||
void Subgraph::decode_to_framework_nodes()
|
||||
{
|
||||
Graph::decode_to_framework_nodes();
|
||||
find_inputs_from_parent();
|
||||
}
|
||||
|
||||
const std::vector<Output<ngraph::Node>> Subgraph::get_inputs_from_parent() const
|
||||
{
|
||||
OutputVector result;
|
||||
for (const auto& name : m_inputs_from_parent)
|
||||
{
|
||||
result.push_back(m_parent_graph_cache->get_node(name));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void Subgraph::infer_inputs_from_parent()
|
||||
{
|
||||
for (auto& it : m_parameter_to_parent_node_map)
|
||||
{
|
||||
const auto& node = m_parent_graph_cache->get_node(it.second);
|
||||
auto& parameter = it.first;
|
||||
parameter->set_element_type(node.get_element_type());
|
||||
parameter->set_partial_shape(node.get_partial_shape());
|
||||
}
|
||||
}
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
||||
|
@ -31,13 +31,14 @@ namespace ngraph
|
||||
|
||||
Graph& operator=(const Graph&) = delete;
|
||||
Graph& operator=(Graph&&) = default;
|
||||
virtual std::shared_ptr<Function> convert();
|
||||
std::shared_ptr<Function> decode();
|
||||
const std::vector<Node>& get_nodes() const { return m_nodes; }
|
||||
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
|
||||
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
|
||||
OutputVector get_ng_outputs() const;
|
||||
const ParameterVector& get_ng_parameters() const { return m_parameters; }
|
||||
bool is_node_in_cache(const std::string& name) const;
|
||||
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
|
||||
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
|
||||
const std::string& get_name() const { return m_model->get_graph().name(); }
|
||||
OutputVector make_ng_nodes(const Node& onnx_node) const;
|
||||
const GraphCache& get_graph_cache() const;
|
||||
@ -60,6 +61,11 @@ namespace ngraph
|
||||
const OutputVector& ng_node_vector) const;
|
||||
|
||||
protected:
|
||||
virtual void decode_to_framework_nodes();
|
||||
void convert_to_ngraph_nodes();
|
||||
void remove_dangling_parameters();
|
||||
std::shared_ptr<Function> create_function();
|
||||
|
||||
ParameterVector m_parameters;
|
||||
std::unique_ptr<Model> m_model;
|
||||
std::unique_ptr<GraphCache> m_cache;
|
||||
@ -82,9 +88,11 @@ namespace ngraph
|
||||
/// \param[in] parent_graph The reference to the parent graph.
|
||||
Subgraph(std::unique_ptr<Model>&& model, const Graph& parent_graph);
|
||||
|
||||
/// \brief Return outputs which are on the edge the subgraph and the parent graph.
|
||||
/// \brief Return nodes which are on the edge the subgraph and the parent graph.
|
||||
/// \return Vector of edge nodes from parent scope.
|
||||
const std::vector<Output<ngraph::Node>> get_outputs_from_parent() const;
|
||||
const std::vector<Output<ngraph::Node>> get_inputs_from_parent() const;
|
||||
|
||||
std::shared_ptr<Function> convert() override;
|
||||
|
||||
Subgraph() = delete;
|
||||
|
||||
@ -94,8 +102,17 @@ namespace ngraph
|
||||
Subgraph& operator=(const Subgraph&) = delete;
|
||||
Subgraph& operator=(Subgraph&&) = default;
|
||||
|
||||
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const override;
|
||||
void infer_inputs_from_parent();
|
||||
|
||||
private:
|
||||
std::vector<Output<ngraph::Node>> m_outputs_from_parent;
|
||||
void decode_to_framework_nodes() override;
|
||||
void find_inputs_from_parent();
|
||||
|
||||
const GraphCache* m_parent_graph_cache;
|
||||
std::vector<std::string> m_inputs_from_parent;
|
||||
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::string>
|
||||
m_parameter_to_parent_node_map;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
|
||||
|
@ -39,55 +39,5 @@ namespace ngraph
|
||||
{
|
||||
return (m_graph_cache_map.count(name) > 0);
|
||||
}
|
||||
|
||||
NodeScope GraphCache::node_scope(const std::string& name) const
|
||||
{
|
||||
return contains(name) ? NodeScope::ParentGraph : NodeScope::Lack;
|
||||
}
|
||||
|
||||
SubgraphCache::SubgraphCache(const GraphCache& parent_graph_cache)
|
||||
: m_parent_graph_cache{&parent_graph_cache}
|
||||
{
|
||||
if (m_parent_graph_cache == nullptr)
|
||||
{
|
||||
throw ngraph_error("Parent graph cache is not initialized");
|
||||
}
|
||||
}
|
||||
|
||||
Output<ngraph::Node> SubgraphCache::get_node(const std::string& name) const
|
||||
{
|
||||
// present in subgraph scope
|
||||
if (GraphCache::contains(name))
|
||||
{
|
||||
return GraphCache::get_node(name);
|
||||
}
|
||||
else // present in parent graph scope
|
||||
{
|
||||
return m_parent_graph_cache->get_node(name);
|
||||
}
|
||||
}
|
||||
|
||||
bool SubgraphCache::contains(const std::string& name) const
|
||||
{
|
||||
// the node is in subgraph or in parent graph scope
|
||||
return GraphCache::contains(name) || m_parent_graph_cache->contains(name);
|
||||
}
|
||||
|
||||
NodeScope SubgraphCache::node_scope(const std::string& name) const
|
||||
{
|
||||
if (GraphCache::contains(name))
|
||||
{
|
||||
return NodeScope::SubGraph;
|
||||
}
|
||||
else if (m_parent_graph_cache->contains(name))
|
||||
{
|
||||
return NodeScope::ParentGraph;
|
||||
}
|
||||
else
|
||||
{
|
||||
return NodeScope::Lack;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -14,17 +14,6 @@ namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
/// \brief Enum which determines scope (visibility) of nodes in GraphCache.
|
||||
enum class NodeScope
|
||||
{
|
||||
// in parent graph scope
|
||||
ParentGraph = 1,
|
||||
// in subgraph scope
|
||||
SubGraph,
|
||||
// not available at all
|
||||
Lack
|
||||
};
|
||||
|
||||
/// \brief GraphCache stores and provides access to ONNX graph initializers.
|
||||
class GraphCache
|
||||
{
|
||||
@ -58,58 +47,10 @@ namespace ngraph
|
||||
/// \return true if the node named `name` exist in the cache, false otherwise.
|
||||
virtual bool contains(const std::string& name) const;
|
||||
|
||||
/// \brief Return NodeScope enum which determines scope of the node.
|
||||
/// \note If the method is called on GraphCache the ParentGraph enum
|
||||
/// value is retunred always.
|
||||
///
|
||||
/// \param[in] name The name of the node.
|
||||
///
|
||||
/// \return SubGraph if node belongs to SubgraphCache, ParentGraph if
|
||||
/// is avalible in parent_graph_cache, otherwise Lack
|
||||
virtual NodeScope node_scope(const std::string& name) const;
|
||||
|
||||
virtual ~GraphCache() = default;
|
||||
|
||||
private:
|
||||
std::map<std::string, Output<ngraph::Node>> m_graph_cache_map;
|
||||
};
|
||||
|
||||
class SubgraphCache : public GraphCache
|
||||
{
|
||||
public:
|
||||
/// \brief Constructs a SubgraphCache class object.
|
||||
///
|
||||
/// \param[in] parent_graph_cache The reference to the parent graph.
|
||||
SubgraphCache(const GraphCache& parent_graph_cache);
|
||||
|
||||
/// \brief Get the node from the cache (subgraph or parent graph)
|
||||
///
|
||||
/// \note If the node is not found the ngraph_error exception is thrown.
|
||||
///
|
||||
/// \param[in] name The name of the node.
|
||||
///
|
||||
/// \return The node named `name` from subgraph (as present) or from parent graph.
|
||||
Output<ngraph::Node> get_node(const std::string& name) const override;
|
||||
|
||||
/// \brief Return true if the node named `name` exist in the cache.
|
||||
///
|
||||
/// \param[in] name The name of the node.
|
||||
///
|
||||
/// \return true if the node named `name` exist in the cache
|
||||
/// (subgraph or parent graph), false otherwise.
|
||||
bool contains(const std::string& name) const override;
|
||||
|
||||
/// \brief Return NodeScope enum which determines scope of the node.
|
||||
///
|
||||
/// \param[in] name The name of the node.
|
||||
///
|
||||
/// \return SubGraph if the node belongs to SubgraphCache, ParentGraph if
|
||||
/// is avalible in parent_graph_cache, otherwise Lack
|
||||
NodeScope node_scope(const std::string& name) const override;
|
||||
|
||||
private:
|
||||
const GraphCache* m_parent_graph_cache;
|
||||
};
|
||||
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "core/model.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "onnx_import/onnx_framework_node.hpp"
|
||||
#include "ops_bridge.hpp"
|
||||
|
||||
namespace ngraph
|
||||
|
@ -26,6 +26,29 @@ namespace ngraph
|
||||
, m_graph{&graph}
|
||||
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
|
||||
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
|
||||
{
|
||||
const auto it =
|
||||
std::find_if(std::begin(m_attributes),
|
||||
std::end(m_attributes),
|
||||
[&](const Attribute& attribute) { return attribute.is_graph(); });
|
||||
m_has_subgraph = it != std::end(m_attributes);
|
||||
if (m_has_subgraph)
|
||||
{
|
||||
m_subgraph = std::make_shared<Subgraph>(it->get_subgraph(*m_graph));
|
||||
}
|
||||
}
|
||||
|
||||
Impl(const ONNX_NAMESPACE::NodeProto& node_proto,
|
||||
const Graph& graph,
|
||||
std::shared_ptr<Subgraph> subgraph)
|
||||
: m_node_proto{&node_proto}
|
||||
, m_name{node_proto.has_name() ? node_proto.name() : ""}
|
||||
, m_domain{get_node_domain(node_proto)}
|
||||
, m_graph{&graph}
|
||||
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
|
||||
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
|
||||
, m_has_subgraph(subgraph != nullptr)
|
||||
, m_subgraph(subgraph)
|
||||
{
|
||||
}
|
||||
|
||||
@ -44,9 +67,8 @@ namespace ngraph
|
||||
|
||||
bool has_attribute(const std::string& name) const;
|
||||
|
||||
Subgraph get_subgraph_from_attribute(
|
||||
const std::string& name,
|
||||
const std::map<std::size_t, std::string>& carried_dependencies_map) const;
|
||||
bool has_subgraph() const;
|
||||
std::shared_ptr<Subgraph> get_subgraph() const;
|
||||
|
||||
template <typename T>
|
||||
T get_attribute_value(const std::string& name, T default_value) const;
|
||||
@ -58,6 +80,8 @@ namespace ngraph
|
||||
const Graph& graph() const;
|
||||
|
||||
private:
|
||||
Subgraph get_subgraph_from_attribute(const std::string& name) const;
|
||||
|
||||
const ONNX_NAMESPACE::NodeProto* m_node_proto;
|
||||
std::string m_name;
|
||||
std::string m_domain;
|
||||
@ -65,6 +89,9 @@ namespace ngraph
|
||||
std::vector<Attribute> m_attributes;
|
||||
std::vector<std::reference_wrapper<const std::string>> m_output_names;
|
||||
mutable std::string m_description;
|
||||
|
||||
bool m_has_subgraph;
|
||||
std::shared_ptr<Subgraph> m_subgraph;
|
||||
};
|
||||
|
||||
const ONNX_NAMESPACE::NodeProto& Node::Impl::node_proto() const { return *m_node_proto; }
|
||||
@ -94,9 +121,7 @@ namespace ngraph
|
||||
return it != std::end(m_attributes);
|
||||
}
|
||||
|
||||
Subgraph Node::Impl::get_subgraph_from_attribute(
|
||||
const std::string& name,
|
||||
const std::map<std::size_t, std::string>& carried_dependencies_map) const
|
||||
Subgraph Node::Impl::get_subgraph_from_attribute(const std::string& name) const
|
||||
{
|
||||
auto it = std::find_if(
|
||||
std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
|
||||
@ -106,9 +131,13 @@ namespace ngraph
|
||||
{
|
||||
throw error::node::UnknownAttribute{this->name(), name};
|
||||
}
|
||||
return it->get_subgraph(graph(), carried_dependencies_map);
|
||||
return it->get_subgraph(*m_graph);
|
||||
}
|
||||
|
||||
bool Node::Impl::has_subgraph() const { return m_has_subgraph; }
|
||||
|
||||
std::shared_ptr<Subgraph> Node::Impl::get_subgraph() const { return m_subgraph; }
|
||||
|
||||
template <typename T>
|
||||
T Node::Impl::get_attribute_value(const std::string& name, T default_value) const
|
||||
{
|
||||
@ -140,8 +169,7 @@ namespace ngraph
|
||||
template <>
|
||||
Subgraph Node::Impl::get_attribute_value(const std::string& name) const
|
||||
{
|
||||
const std::map<std::size_t, std::string> empty_map;
|
||||
return get_subgraph_from_attribute(name, empty_map);
|
||||
return get_subgraph_from_attribute(name);
|
||||
}
|
||||
|
||||
OutputVector Node::Impl::get_ng_nodes(const Node& node) const
|
||||
@ -196,7 +224,9 @@ namespace ngraph
|
||||
}
|
||||
|
||||
Node::Node(const Node& other)
|
||||
: m_pimpl{new Impl{other.m_pimpl->node_proto(), other.m_pimpl->graph()},
|
||||
: m_pimpl{new Impl{other.m_pimpl->node_proto(),
|
||||
other.m_pimpl->graph(),
|
||||
other.get_subgraph()},
|
||||
[](Impl* impl) { delete impl; }}
|
||||
{
|
||||
}
|
||||
@ -219,12 +249,9 @@ namespace ngraph
|
||||
return m_pimpl->has_attribute(name);
|
||||
}
|
||||
|
||||
Subgraph Node::get_subgraph_from_attribute(
|
||||
const std::string& name,
|
||||
const std::map<std::size_t, std::string>& carried_dependencies_map) const
|
||||
{
|
||||
return m_pimpl->get_subgraph_from_attribute(name, carried_dependencies_map);
|
||||
}
|
||||
bool Node::has_subgraph() const { return m_pimpl->has_subgraph(); }
|
||||
|
||||
std::shared_ptr<Subgraph> Node::get_subgraph() const { return m_pimpl->get_subgraph(); }
|
||||
|
||||
std::vector<std::string> Node::get_attribute_names() const
|
||||
{
|
||||
@ -462,7 +489,6 @@ namespace ngraph
|
||||
{
|
||||
return m_pimpl->template get_attribute_value<std::vector<Graph>>(name);
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
||||
|
@ -36,7 +36,10 @@ namespace ngraph
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"NullNode", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NullNode() = default;
|
||||
NullNode()
|
||||
: Node(1)
|
||||
{
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
@ -19,20 +19,6 @@ namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
namespace error
|
||||
{
|
||||
namespace value_info
|
||||
{
|
||||
struct unspecified_element_type : ngraph_error
|
||||
{
|
||||
unspecified_element_type()
|
||||
: ngraph_error{"value info has no element type specified"}
|
||||
{
|
||||
}
|
||||
};
|
||||
} // namespace value_info
|
||||
} // namespace error
|
||||
|
||||
class ValueInfo
|
||||
{
|
||||
public:
|
||||
@ -65,12 +51,12 @@ namespace ngraph
|
||||
const PartialShape& get_shape() const { return m_partial_shape; }
|
||||
const element::Type& get_element_type() const
|
||||
{
|
||||
if (!m_value_info_proto->type().tensor_type().has_elem_type())
|
||||
if (m_value_info_proto->type().tensor_type().has_elem_type())
|
||||
{
|
||||
throw error::value_info::unspecified_element_type{};
|
||||
return common::get_ngraph_element_type(
|
||||
m_value_info_proto->type().tensor_type().elem_type());
|
||||
}
|
||||
return common::get_ngraph_element_type(
|
||||
m_value_info_proto->type().tensor_type().elem_type());
|
||||
return ngraph::element::dynamic;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node>
|
||||
|
34
ngraph/frontend/onnx_import/src/onnx_framework_node.cpp
Normal file
34
ngraph/frontend/onnx_import/src/onnx_framework_node.cpp
Normal file
@ -0,0 +1,34 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <onnx_import/onnx_framework_node.hpp>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace frontend
|
||||
{
|
||||
NGRAPH_RTTI_DEFINITION(ONNXFrameworkNode, "ONNXFrameworkNode", 1);
|
||||
|
||||
std::shared_ptr<Node>
|
||||
ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const
|
||||
{
|
||||
return std::make_shared<ONNXFrameworkNode>(m_node, inputs);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1);
|
||||
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
@ -77,10 +77,18 @@ namespace ngraph
|
||||
loop_carried_dependencies[i].get_node()->get_friendly_name();
|
||||
}
|
||||
|
||||
const Subgraph& body_graph{
|
||||
node.get_subgraph_from_attribute("body", loop_carried_dependencies_map)};
|
||||
auto body_outputs = body_graph.get_ng_outputs();
|
||||
const auto& body_inputs = body_graph.get_ng_parameters();
|
||||
auto body_graph = node.get_subgraph();
|
||||
auto body_outputs = body_graph->get_ng_outputs();
|
||||
const auto& body_inputs = body_graph->get_ng_parameters();
|
||||
|
||||
// Infer loop body inputs' element type based on carried dependencies
|
||||
for (size_t i = 0; i < loop_carried_dependencies.size(); i++)
|
||||
{
|
||||
body_inputs[i + 2]->set_element_type(
|
||||
loop_carried_dependencies[i].get_element_type());
|
||||
body_inputs[i + 2]->set_partial_shape(
|
||||
loop_carried_dependencies[i].get_partial_shape());
|
||||
}
|
||||
|
||||
// optional inputs
|
||||
Output<ngraph::Node> trip_count;
|
||||
@ -190,22 +198,22 @@ namespace ngraph
|
||||
final_values.push_back(loop->get_iter_value(*body_outputs_it++, -1));
|
||||
}
|
||||
|
||||
const auto& outputs_from_parent = body_graph.get_outputs_from_parent();
|
||||
const auto& inputs_from_parent = body_graph->get_inputs_from_parent();
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
static_cast<size_t>(std::distance(body_inputs_it, body_inputs.end())) ==
|
||||
outputs_from_parent.size(),
|
||||
inputs_from_parent.size(),
|
||||
"Expected number of invariant parameters is"
|
||||
" not equal number of provided outputs from parent scope");
|
||||
" not equal number of provided inputs from parent scope");
|
||||
|
||||
// Set-up parameters from parent graph which are not changed during Loop's
|
||||
// iterations
|
||||
for (auto out_from_parent_it = outputs_from_parent.begin();
|
||||
for (auto in_from_parent_it = inputs_from_parent.begin();
|
||||
body_inputs_it != body_inputs.end() &&
|
||||
out_from_parent_it != outputs_from_parent.end();
|
||||
++body_inputs_it, ++out_from_parent_it)
|
||||
in_from_parent_it != inputs_from_parent.end();
|
||||
++body_inputs_it, ++in_from_parent_it)
|
||||
{
|
||||
loop->set_invariant_input(*body_inputs_it, *out_from_parent_it);
|
||||
loop->set_invariant_input(*body_inputs_it, *in_from_parent_it);
|
||||
}
|
||||
|
||||
// Set-up scan outputs
|
||||
|
@ -6,7 +6,9 @@
|
||||
|
||||
#include "core/graph.hpp"
|
||||
#include "core/model.hpp"
|
||||
#include "core/null_node.hpp"
|
||||
#include "core/transform.hpp"
|
||||
#include "onnx_import/onnx_framework_node.hpp"
|
||||
#include "onnx_import/utils/onnx_internal.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -15,21 +17,81 @@ namespace ngraph
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
std::shared_ptr<Function>
|
||||
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
|
||||
void remove_dangling_parameters(std::shared_ptr<Function>& function)
|
||||
{
|
||||
auto p_model_proto = common::make_unique<ONNX_NAMESPACE::ModelProto>(model_proto);
|
||||
auto model = common::make_unique<Model>(std::move(p_model_proto));
|
||||
|
||||
Graph graph{std::move(model)};
|
||||
auto function = std::make_shared<Function>(
|
||||
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
|
||||
for (std::size_t i{0}; i < function->get_output_size(); ++i)
|
||||
const auto parameters = function->get_parameters();
|
||||
for (auto parameter : parameters)
|
||||
{
|
||||
function->get_output_op(i)->set_friendly_name(
|
||||
graph.get_outputs().at(i).get_name());
|
||||
const auto parameter_users = parameter->get_users();
|
||||
// if a Parameter is connected to a ONNXFrameworkNode that was not converted
|
||||
// during convert_function it means, this Parameter is dangling and we can
|
||||
// remove it from function
|
||||
const bool is_dangling_parameter = std::all_of(
|
||||
parameter_users.begin(),
|
||||
parameter_users.end(),
|
||||
[](const std::shared_ptr<ngraph::Node>& node) -> bool {
|
||||
return std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node) !=
|
||||
nullptr;
|
||||
});
|
||||
if (is_dangling_parameter)
|
||||
{
|
||||
function->remove_parameter(parameter);
|
||||
}
|
||||
}
|
||||
return function;
|
||||
}
|
||||
|
||||
void remove_dangling_results(std::shared_ptr<Function>& function)
|
||||
{
|
||||
const auto results = function->get_results();
|
||||
for (auto result : results)
|
||||
{
|
||||
// we can remove Result from function if after function conversion,
|
||||
// Result is connected to NullNode only
|
||||
const auto result_inputs = result->input_values();
|
||||
const bool is_dangling_result =
|
||||
std::all_of(result_inputs.begin(),
|
||||
result_inputs.end(),
|
||||
[](const Output<ngraph::Node>& node) -> bool {
|
||||
return ngraph::op::is_null(node);
|
||||
});
|
||||
if (is_dangling_result)
|
||||
{
|
||||
function->remove_result(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void convert_decoded_function(std::shared_ptr<Function> function)
|
||||
{
|
||||
for (const auto& node : function->get_ordered_ops())
|
||||
{
|
||||
if (auto raw_node =
|
||||
std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node))
|
||||
{
|
||||
if (auto subgraph_node =
|
||||
std::dynamic_pointer_cast<frontend::ONNXSubgraphFrameworkNode>(
|
||||
node))
|
||||
{
|
||||
subgraph_node->infer_inputs_from_parent();
|
||||
convert_decoded_function(subgraph_node->get_subgraph_body());
|
||||
}
|
||||
const auto& onnx_node = raw_node->get_onnx_node();
|
||||
OutputVector ng_nodes{onnx_node.get_ng_nodes()};
|
||||
if (ng_nodes.size() > raw_node->get_output_size())
|
||||
{
|
||||
ng_nodes.resize(raw_node->get_output_size());
|
||||
}
|
||||
replace_node(raw_node, ng_nodes);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Have to revalidate node because new intpus can affect shape/type
|
||||
// propagation for already translated nodes
|
||||
node->revalidate_and_infer_types();
|
||||
}
|
||||
}
|
||||
remove_dangling_parameters(function);
|
||||
remove_dangling_results(function);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
@ -39,7 +101,10 @@ namespace ngraph
|
||||
transform::fixup_legacy_operators(model_proto);
|
||||
transform::update_external_data_paths(model_proto, model_path);
|
||||
|
||||
return detail::convert_to_ng_function(model_proto);
|
||||
auto p_model_proto = common::make_unique<ONNX_NAMESPACE::ModelProto>(model_proto);
|
||||
auto model = common::make_unique<Model>(std::move(p_model_proto));
|
||||
Graph graph{std::move(model)};
|
||||
return graph.convert();
|
||||
}
|
||||
} // namespace detail
|
||||
} // namespace onnx_import
|
||||
|
@ -390,8 +390,7 @@ def test_cast_errors():
|
||||
for name, value in zip(node.input, [input_data])
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info(name, onnx.TensorProto.FLOAT16, value.shape)
|
||||
for name, value in zip(node.output, ())
|
||||
make_tensor_value_info(node.output[0], onnx.TensorProto.FLOAT16, input_data.shape)
|
||||
] # type: ignore
|
||||
|
||||
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
|
||||
@ -406,8 +405,7 @@ def test_cast_errors():
|
||||
for name, value in zip(node.input, [input_data])
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape)
|
||||
for name, value in zip(node.output, ())
|
||||
make_tensor_value_info(node.output[0], onnx.TensorProto.INT32, input_data.shape)
|
||||
] # type: ignore
|
||||
|
||||
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
|
||||
@ -422,8 +420,7 @@ def test_cast_errors():
|
||||
for name, value in zip(node.input, [input_data])
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape)
|
||||
for name, value in zip(node.output, ())
|
||||
make_tensor_value_info(node.output[0], onnx.TensorProto.INT32, input_data.shape)
|
||||
] # type: ignore
|
||||
|
||||
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
|
||||
@ -438,8 +435,7 @@ def test_cast_errors():
|
||||
for name, value in zip(node.input, [input_data])
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info(name, onnx.TensorProto.COMPLEX128, value.shape)
|
||||
for name, value in zip(node.output, ())
|
||||
make_tensor_value_info(node.output[0], onnx.TensorProto.COMPLEX128, input_data.shape)
|
||||
] # type: ignore
|
||||
|
||||
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
|
||||
|
@ -2,7 +2,6 @@ ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "target_shape"
|
||||
output: "output"
|
||||
op_type: "ConstantFill"
|
||||
attribute {
|
||||
|
Loading…
Reference in New Issue
Block a user