NodeVector -> OutputVector replacement (#1272)

Katarzyna Mitrus 2020-07-29 17:18:56 +02:00 committed by GitHub
parent dec7df17ed
commit f34511642a
294 changed files with 882 additions and 900 deletions
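The change is mechanical throughout the diff below: ONNX importer helpers that returned NodeVector (a vector of std::shared_ptr&lt;Node&gt;) now return OutputVector (a vector of Output&lt;Node&gt;), and call sites move from node->... queries to the Output&lt;Node&gt; equivalents, recovering the owning node via get_node() / get_node_shared_ptr() only where a node pointer is still required. A minimal before/after sketch of that pattern, assuming the usual nGraph headers; the relu_subgraph_* helpers are illustrative names, not functions from this commit:

// Sketch of the NodeVector -> OutputVector migration pattern (illustrative only).
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/op/relu.hpp"

using namespace ngraph;

// Old style: vector of shared_ptr<Node>, queries go through the node.
NodeVector relu_subgraph_old(const std::shared_ptr<Node>& input)
{
    auto relu = std::make_shared<op::v0::Relu>(input);
    Shape shape = relu->get_shape();              // node-level query
    return {relu};                                // vector of shared_ptr<Node>
}

// New style: vector of Output<Node>, queries go through the output handle.
OutputVector relu_subgraph_new(const Output<Node>& input)
{
    auto relu = std::make_shared<op::v0::Relu>(input);
    Shape shape = relu->output(0).get_shape();    // output-level query
    // A Node* is still recoverable where needed (friendly names, replace_node):
    relu->output(0).get_node()->set_friendly_name("relu");
    return {relu->output(0)};                     // vector of Output<Node>
}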

@@ -23,8 +23,8 @@ void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
 return false;
 }
 auto last_node = batch_to_space->decompose_op()[0];
-last_node->set_friendly_name(batch_to_space->get_friendly_name());
-ngraph::replace_node(batch_to_space, last_node);
+last_node.get_node()->set_friendly_name(batch_to_space->get_friendly_name());
+ngraph::replace_node(batch_to_space, last_node.get_node_shared_ptr());
 return true;
 };

@@ -23,8 +23,8 @@ void ngraph::pass::ConvertSpaceToBatch::convert_space_to_batch() {
 return false;
 }
 auto last_node = space_to_batch->decompose_op()[0];
-last_node->set_friendly_name(space_to_batch->get_friendly_name());
-ngraph::replace_node(space_to_batch, last_node);
+last_node.get_node()->set_friendly_name(space_to_batch->get_friendly_name());
+ngraph::replace_node(space_to_batch, last_node.get_node_shared_ptr());
 return true;
 };

@@ -68,7 +68,7 @@ Output<Node> builder::MatmulFactory::get_right()
 return m_inputs.at(1);
 }
-NodeVector builder::MatmulFactory::make_matmul_op()
+OutputVector builder::MatmulFactory::make_matmul_op()
 {
 auto left = get_left();
 auto right = get_right();

@@ -40,8 +40,8 @@ namespace ngraph
 /// \brief Create a sub-graph representing an ONNX MatMul operation.
 ///
-/// \return NodeVector containing the sub-graph output node.
-virtual NodeVector make_matmul_op();
+/// \return OutputVector containing the sub-graph output node.
+virtual OutputVector make_matmul_op();
 protected:
 /// \return Output representing the left operand.

@@ -172,8 +172,9 @@ namespace ngraph
 }
 }
-void
-check_concat(const NodeVector& args, const NodeVector& mins, const NodeVector& maxs)
+void check_concat(const OutputVector& args,
+const OutputVector& mins,
+const OutputVector& maxs)
 {
 auto size = args.size();
 if (size != mins.size() || size != maxs.size())
@@ -184,17 +185,17 @@
 {
 auto min = mins[i];
 auto max = maxs[i];
-auto type = min->get_element_type();
-if (type != max->get_element_type())
+auto type = min.get_element_type();
+if (type != max.get_element_type())
 {
 throw ngraph_error("check_concat: min and max must have same type");
 }
-if (min->get_shape() != Shape{1} || max->get_shape() != Shape{1})
+if (min.get_shape() != Shape{1} || max.get_shape() != Shape{1})
 {
 throw ngraph_error("check_concat: min/max shape not Shape{1}: " +
-vector_to_string(min->get_shape()) +
-vector_to_string(max->get_shape()));
+vector_to_string(min.get_shape()) +
+vector_to_string(max.get_shape()));
 }
 }
 }

@@ -64,9 +64,9 @@ namespace ngraph
 const ngraph::element::Type& output_type,
 const bool requantize = true);
-void check_concat(const NodeVector& args,
-const NodeVector& mins,
-const NodeVector& maxs);
+void check_concat(const OutputVector& args,
+const OutputVector& mins,
+const OutputVector& maxs);
 }
 }
 }

@@ -25,28 +25,28 @@ namespace ngraph
 {
 namespace builder
 {
-shared_ptr<Node> QuantizedConcatBuilder(const NodeVector& args,
+shared_ptr<Node> QuantizedConcatBuilder(const OutputVector& args,
 size_t concatenation_axis,
-const NodeVector& mins,
-const NodeVector& maxs)
+const OutputVector& mins,
+const OutputVector& maxs)
 {
 quantization_utils::check_concat(args, mins, maxs);
-auto quant_type = args[0]->get_element_type();
+auto quant_type = args[0].get_element_type();
 // output scale
 auto min = make_shared<op::Min>(make_shared<op::Concat>(mins, 0), ngraph::AxisSet{0});
 auto max = make_shared<op::Max>(make_shared<op::Concat>(maxs, 0), ngraph::AxisSet{0});
 auto out_scale = quantization_utils::get_scale(min, max, quant_type);
-NodeVector rescaled_args(args.size());
+OutputVector rescaled_args(args.size());
 for (size_t i = 0; i < args.size(); ++i)
 {
-auto q_type = args[i]->get_element_type();
+auto q_type = args[i].get_element_type();
 auto in_scale = make_shared<ngraph::op::Reshape>(
 quantization_utils::get_scale(mins[i], maxs[i], q_type),
 AxisVector{0},
 Shape{});
-auto zero = make_constant(q_type, in_scale->get_shape(), 0);
+auto zero = make_constant(q_type, in_scale->get_output_shape(0), 0);
 rescaled_args[i] =
 make_shared<op::Dequantize>(args[i], in_scale, zero, element::f32, AxisSet{});
@@ -58,7 +58,7 @@ namespace ngraph
 AxisSet{},
 op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
 }
-OutputVector base = as_output_vector(args);
+OutputVector base = args;
 for (auto node : mins)
 {
 base.push_back(node);

@@ -32,9 +32,9 @@ namespace ngraph
 namespace builder
 {
 NGRAPH_API
-std::shared_ptr<Node> QuantizedConcatBuilder(const NodeVector& args,
+std::shared_ptr<Node> QuantizedConcatBuilder(const OutputVector& args,
 size_t concatenation_axis,
-const NodeVector& mins,
-const NodeVector& maxs);
+const OutputVector& mins,
+const OutputVector& maxs);
 }
 }

@@ -47,37 +47,13 @@ namespace
 std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
 ->add_provenance_group_members_above({output}));
 }
-/// \brief Return the outputs of the node as vector.
-///
-/// \param[in] node Node with multiple outputs.
-///
-/// \return Vector of outputs of input node.
-NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
-{
-const auto outputs_number = node->get_output_size();
-ngraph::NodeVector outputs(outputs_number);
-for (int i = 0; i < outputs_number; ++i)
-{
-if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
-{
-outputs[i] = node->get_output_as_single_output_node(i);
-}
-else
-{
-outputs[i] = std::make_shared<op::GetOutputElement>(node, i);
-}
-}
-return outputs;
-}
 }
-NodeVector builder::split(const Output<ngraph::Node>& value,
-const std::vector<size_t>& length_parts,
-size_t axis)
+OutputVector
+builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
 {
 size_t start_index{0};
-NodeVector outputs;
+OutputVector outputs;
 for (const auto& length_part : length_parts)
 {
 size_t end_index{start_index + length_part};
@@ -87,7 +63,7 @@ NodeVector builder::split(const Output<ngraph::Node>& value,
 return outputs;
 }
-NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
+OutputVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
 {
 size_t axis_to_split{static_cast<size_t>(axis)};
 if (axis < 0)
@@ -100,9 +76,9 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
 return split(value, length_parts, axis_to_split);
 }
-NodeVector builder::opset1::split(const Output<Node>& value,
+OutputVector builder::opset1::split(const Output<Node>& value,
 const std::vector<size_t>& split_lengths,
 int64_t axis)
 {
 const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
 const auto split_lengths_node =
@@ -110,13 +86,13 @@ NodeVector builder::opset1::split(const Output<Node>& value,
 const auto variadic_split =
 std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
-return get_outputs(variadic_split);
+return variadic_split->outputs();
 }
-NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
+OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
 {
 const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
 const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
-return get_outputs(split);
+return split->outputs();
 }

@@ -31,9 +31,9 @@ namespace ngraph
 ///
 /// \return The vector containing multiple nodes we split input node into.
 ///
-NodeVector split(const Output<Node>& value,
+OutputVector split(const Output<Node>& value,
 const std::vector<size_t>& length_parts,
 size_t axis = 0);
 /// \brief Split node on specified axis into multiple parts.
 ///
@@ -47,9 +47,9 @@ namespace ngraph
 /// indexing). This means that the axis to split on will be counted from
 /// the back of the tensor (negative values are subtracted from its rank).
 ///
-/// \return The vector containing multiple nodes we split input node into.
+/// \return The vector containing multiple outputs we split input node into.
 ///
-NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
+OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
 namespace opset1
 {
@@ -63,13 +63,13 @@ namespace ngraph
 /// indexing). This means that the axis to split on will be counted from
 /// the back of the tensor (negative values are subtracted from its rank).
 ///
-/// \return The vector containing multiple nodes we split input node into.
+/// \return The vector containing multiple outputs we split input node into.
 /// The vector is output of Split:v1 op
 ///
 NGRAPH_API
-NodeVector split(const Output<Node>& value,
+OutputVector split(const Output<Node>& value,
 const std::vector<size_t>& split_lengths,
 int64_t axis = 0);
 /// \brief Split value on specified axis into multiple parts.
 ///
@@ -88,7 +88,7 @@ namespace ngraph
 /// The vector is output of VariadicSplit:v1 op
 ///
 NGRAPH_API
-NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
+OutputVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
 }
 } // namespace builder
 } // namespace ngraph
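A hedged usage sketch of the builder::opset1::split overload declared above, after the return-type change; the input shape and include paths are assumptions for illustration, not part of this commit:

#include <memory>
#include "ngraph/builder/split.hpp"
#include "ngraph/opsets/opset1.hpp"

using namespace ngraph;

int main()
{
    // A {2, 6} input split into three equal {2, 2} pieces along axis 1.
    auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 6});
    OutputVector parts = builder::opset1::split(data, /*num_splits=*/3, /*axis=*/1);
    // Each element is an Output<Node> of the underlying Split:v1 op,
    // so it can feed further ops directly without GetOutputElement.
    return parts.size() == 3 ? 0 : 1;
}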

@@ -157,7 +157,7 @@ namespace ngraph
 m_nodes.emplace_back(node_proto, *this);
 const Node& node{m_nodes.back()};
-NodeVector ng_nodes{node.get_ng_nodes()};
+OutputVector ng_nodes{node.get_ng_nodes()};
 // Iterate over the number of outputs for given node in graph.
 // Some of them may be optional and trimmed. See:
 // https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
@@ -174,14 +174,14 @@ namespace ngraph
 return m_cache->contains(name);
 }
-std::shared_ptr<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
+Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
 {
 return m_cache->get_node(name);
 }
-NodeVector Graph::get_ng_outputs() const
+OutputVector Graph::get_ng_outputs() const
 {
-NodeVector results;
+OutputVector results;
 for (const auto& output : m_graph_proto->output())
 {
 results.emplace_back(get_ng_node_from_cache(output.name()));
@@ -189,11 +189,11 @@ namespace ngraph
 return results;
 }
-NodeVector Graph::make_ng_nodes(const Node& onnx_node) const
+OutputVector Graph::make_ng_nodes(const Node& onnx_node) const
 {
 const auto ng_node_factory =
 m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
-NodeVector ng_node_vector;
+OutputVector ng_node_vector;
 try
 {
 ng_node_vector = ng_node_factory(onnx_node);
@@ -223,7 +223,7 @@ namespace ngraph
 }
 void Graph::set_friendly_names(const Node& onnx_node,
-const NodeVector& ng_node_vector) const
+const OutputVector& ng_node_vector) const
 {
 for (int i = 0; i < ng_node_vector.size(); ++i)
 {
@@ -234,7 +234,7 @@ namespace ngraph
 break;
 }
-ng_node_vector[i]->set_friendly_name(onnx_node.output(i));
+ng_node_vector[i].get_node()->set_friendly_name(onnx_node.output(i));
 }
 }
@@ -267,7 +267,7 @@ namespace ngraph
 }
 void Graph::add_provenance_tags(const Node& onnx_node,
-const NodeVector& ng_node_vector) const
+const OutputVector& ng_node_vector) const
 {
 if (!ngraph::get_provenance_enabled())
 {
@@ -278,9 +278,9 @@ namespace ngraph
 const auto ng_inputs = onnx_node.get_ng_inputs();
 ngraph::traverse_nodes(
-ng_node_vector,
+as_node_vector(ng_node_vector),
 [&tag](std::shared_ptr<ngraph::Node> ng_node) { ng_node->add_provenance_tag(tag); },
-ng_inputs);
+as_node_vector(ng_inputs));
 }
 Subgraph::Subgraph(const ONNX_NAMESPACE::GraphProto& proto,

@@ -39,12 +39,12 @@ namespace ngraph
 const std::vector<Node>& get_nodes() const { return m_nodes; }
 const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
 const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
-NodeVector get_ng_outputs() const;
+OutputVector get_ng_outputs() const;
 const ParameterVector& get_ng_parameters() const { return m_parameters; }
 bool is_node_in_cache(const std::string& name) const;
-std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
+Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
 const std::string& get_name() const { return m_graph_proto->name(); }
-NodeVector make_ng_nodes(const Node& onnx_node) const;
+OutputVector make_ng_nodes(const Node& onnx_node) const;
 const GraphCache& get_graph_cache() const;
 protected:
@@ -52,7 +52,8 @@ namespace ngraph
 Model& model,
 std::unique_ptr<GraphCache>&& cache);
-void set_friendly_names(const Node& onnx_node, const NodeVector& ng_node_vector) const;
+void set_friendly_names(const Node& onnx_node,
+const OutputVector& ng_node_vector) const;
 void add_provenance_tag_to_initializer(
 const Tensor& initializer, std::shared_ptr<default_opset::Constant> node) const;
@@ -60,7 +61,8 @@ namespace ngraph
 void add_provenance_tag_to_input(const ValueInfo& input,
 std::shared_ptr<ngraph::Node> node) const;
-void add_provenance_tags(const Node& onnx_node, const NodeVector& ng_node_vector) const;
+void add_provenance_tags(const Node& onnx_node,
+const OutputVector& ng_node_vector) const;
 private:
 const ONNX_NAMESPACE::GraphProto* m_graph_proto;

@@ -21,12 +21,12 @@ namespace ngraph
 {
 namespace onnx_import
 {
-void GraphCache::emplace_node(const std::string& name, std::shared_ptr<ngraph::Node>&& node)
+void GraphCache::emplace_node(const std::string& name, Output<ngraph::Node>&& node)
 {
 m_graph_cache_map[name] = std::move(node);
 }
-std::shared_ptr<ngraph::Node> GraphCache::get_node(const std::string& name) const
+Output<ngraph::Node> GraphCache::get_node(const std::string& name) const
 {
 try
 {
@@ -52,7 +52,7 @@ namespace ngraph
 }
 }
-std::shared_ptr<ngraph::Node> SubgraphCache::get_node(const std::string& name) const
+Output<ngraph::Node> SubgraphCache::get_node(const std::string& name) const
 {
 // present in subgraph scope
 if (GraphCache::contains(name))

@@ -35,7 +35,7 @@ namespace ngraph
 ///
 /// \param[in] name The name of node added to the cache.
 /// \param[in] node The node added to the cache.
-void emplace_node(const std::string& name, std::shared_ptr<ngraph::Node>&& node);
+void emplace_node(const std::string& name, Output<ngraph::Node>&& node);
 /// \brief Get the node from the cache
 ///
@@ -44,7 +44,7 @@ namespace ngraph
 /// \param[in] name The name of the node.
 ///
 /// \return The node named `name`.
-virtual std::shared_ptr<ngraph::Node> get_node(const std::string& name) const;
+virtual Output<ngraph::Node> get_node(const std::string& name) const;
 /// \brief Return true if the node named `name` exist in the cache.
 ///
@@ -54,7 +54,7 @@ namespace ngraph
 virtual bool contains(const std::string& name) const;
 private:
-std::map<std::string, std::shared_ptr<ngraph::Node>> m_graph_cache_map;
+std::map<std::string, Output<ngraph::Node>> m_graph_cache_map;
 };
 class SubgraphCache : public GraphCache
@@ -72,7 +72,7 @@ namespace ngraph
 /// \param[in] name The name of the node.
 ///
 /// \return The node named `name` from subgraph (as present) or from parent graph.
-std::shared_ptr<ngraph::Node> get_node(const std::string& name) const override;
+Output<ngraph::Node> get_node(const std::string& name) const override;
 /// \brief Return true if the node named `name` exist in the cache.
 ///

@@ -40,8 +40,8 @@ namespace ngraph
 }
 const std::vector<Attribute>& attributes() const;
-NodeVector get_ng_nodes(const Node& node) const;
-NodeVector get_ng_inputs() const;
+OutputVector get_ng_nodes(const Node& node) const;
+OutputVector get_ng_inputs() const;
 const std::string& domain() const;
 const std::string& op_type() const;
@@ -140,14 +140,14 @@ namespace ngraph
 return it->get_subgraph(graph());
 }
-NodeVector Node::Impl::get_ng_nodes(const Node& node) const
+OutputVector Node::Impl::get_ng_nodes(const Node& node) const
 {
 return m_graph->make_ng_nodes(node);
 }
-NodeVector Node::Impl::get_ng_inputs() const
+OutputVector Node::Impl::get_ng_inputs() const
 {
-NodeVector result;
+OutputVector result;
 for (const auto& name : m_node_proto->input())
 {
 if (!name.empty())
@@ -156,7 +156,7 @@ namespace ngraph
 }
 else
 {
-result.push_back(std::make_shared<NullNode>());
+result.push_back(std::make_shared<NullNode>()->output(0));
 }
 }
 return result;
@@ -197,8 +197,8 @@ namespace ngraph
 {
 }
-NodeVector Node::get_ng_inputs() const { return m_pimpl->get_ng_inputs(); }
-NodeVector Node::get_ng_nodes() const { return m_pimpl->get_ng_nodes(*this); }
+OutputVector Node::get_ng_inputs() const { return m_pimpl->get_ng_inputs(); }
+OutputVector Node::get_ng_nodes() const { return m_pimpl->get_ng_nodes(*this); }
 const std::string& Node::domain() const { return m_pimpl->domain(); }
 const std::string& Node::op_type() const { return m_pimpl->op_type(); }
 const std::string& Node::get_description() const { return m_pimpl->description(); }

@@ -64,8 +64,8 @@ namespace ngraph
 Node& operator=(Node&&) noexcept = delete;
 Node& operator=(const Node&) = delete;
-NodeVector get_ng_inputs() const;
-NodeVector get_ng_nodes() const;
+OutputVector get_ng_inputs() const;
+OutputVector get_ng_nodes() const;
 const std::string& domain() const;
 const std::string& op_type() const;
 const std::string& get_name() const;

@@ -42,3 +42,8 @@ bool ngraph::op::is_null(const std::shared_ptr<ngraph::Node>& node)
 {
 return is_null(node.get());
 }
+bool ngraph::op::is_null(const Output<ngraph::Node>& output)
+{
+return is_null(output.get_node());
+}

@@ -29,6 +29,8 @@ namespace ngraph
 bool is_null(const ngraph::Node* node);
 ONNX_IMPORTER_API
 bool is_null(const std::shared_ptr<ngraph::Node>& node);
+ONNX_IMPORTER_API
+bool is_null(const Output<ngraph::Node>& output);
 }
 namespace onnx_import
 {

@@ -28,7 +28,7 @@ namespace ngraph
 namespace onnx_import
 {
 /// \brief Function which transforms single ONNX operator to nGraph sub-graph.
-using Operator = std::function<NodeVector(const Node&)>;
+using Operator = std::function<OutputVector(const Node&)>;
 /// \brief Map which contains ONNX operators accessible by std::string value as a key.
 using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;

@@ -30,7 +30,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector abs(const Node& node)
+inline OutputVector abs(const Node& node)
 {
 return {std::make_shared<default_opset::Abs>(node.get_ng_inputs().at(0))};
 }

@@ -30,7 +30,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector acos(const Node& node)
+inline OutputVector acos(const Node& node)
 {
 return {std::make_shared<default_opset::Acos>(node.get_ng_inputs().at(0))};
 }

@@ -28,7 +28,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector acosh(const Node& node)
+inline OutputVector acosh(const Node& node)
 {
 return {std::make_shared<default_opset::Acosh>(node.get_ng_inputs().at(0))};
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector add(const Node& node)
+OutputVector add(const Node& node)
 {
 const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
 Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
@@ -45,7 +45,7 @@ namespace ngraph
 namespace set_7
 {
-NodeVector add(const Node& node)
+OutputVector add(const Node& node)
 {
 return {std::make_shared<default_opset::Add>(node.get_ng_inputs().at(0),
 node.get_ng_inputs().at(1))};

@@ -29,13 +29,13 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector add(const Node& node);
+OutputVector add(const Node& node);
 } // namespace set_1
 namespace set_7
 {
-NodeVector add(const Node& node);
+OutputVector add(const Node& node);
 } // namespace set_7

@@ -31,7 +31,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector logical_and(const Node& node)
+inline OutputVector logical_and(const Node& node)
 {
 return {std::make_shared<default_opset::LogicalAnd>(
 node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};

@@ -25,7 +25,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector argmax(const Node& node)
+OutputVector argmax(const Node& node)
 {
 const utils::ArgMinMaxFactory arg_factory(node);
 return {arg_factory.make_arg_max()};

@@ -33,7 +33,7 @@ namespace ngraph
 ///
 /// \return The vector containing an Ngraph node which produces the output
 /// of an ONNX ArgMax operation.
-NodeVector argmax(const Node& node);
+OutputVector argmax(const Node& node);
 } // namespace set_1

@@ -25,7 +25,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector argmin(const Node& node)
+OutputVector argmin(const Node& node)
 {
 const utils::ArgMinMaxFactory arg_factory(node);
 return {arg_factory.make_arg_min()};

@@ -33,7 +33,7 @@ namespace ngraph
 ///
 /// \return The vector containing an Ngraph node which produces the output
 /// of an ONNX ArgMin operation.
-NodeVector argmin(const Node& node);
+OutputVector argmin(const Node& node);
 } // namespace set_1

@@ -31,7 +31,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector asin(const Node& node)
+inline OutputVector asin(const Node& node)
 {
 return {std::make_shared<default_opset::Asin>(node.get_ng_inputs().at(0))};
 }

@@ -28,7 +28,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector asinh(const Node& node)
+inline OutputVector asinh(const Node& node)
 {
 return {std::make_shared<default_opset::Asinh>(node.get_ng_inputs().at(0))};
 }

@@ -30,7 +30,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector atan(const Node& node)
+inline OutputVector atan(const Node& node)
 {
 return {std::make_shared<default_opset::Atan>(node.get_ng_inputs().at(0))};
 }

@@ -28,7 +28,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector atanh(const Node& node)
+inline OutputVector atanh(const Node& node)
 {
 return {std::make_shared<default_opset::Atanh>(node.get_ng_inputs().at(0))};
 }

@@ -26,7 +26,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector average_pool(const Node& node)
+OutputVector average_pool(const Node& node)
 {
 return pooling::LocalPoolingFactory(node).make_avg_pool();
 }

@@ -33,7 +33,7 @@ namespace ngraph
 ///
 /// \return The vector containing Ngraph nodes producing output of ONNX AveragePool
 /// operation.
-NodeVector average_pool(const Node& node);
+OutputVector average_pool(const Node& node);
 } // namespace set_1

@@ -30,14 +30,14 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector batch_norm(const Node& node)
+OutputVector batch_norm(const Node& node)
 {
-NodeVector inputs{node.get_ng_inputs()};
+OutputVector inputs{node.get_ng_inputs()};
 auto x = inputs.at(0);
 auto scale = inputs.at(1);
 auto bias = inputs.at(2);
-std::shared_ptr<ngraph::Node> mean{nullptr};
-std::shared_ptr<ngraph::Node> var{nullptr};
+Output<ngraph::Node> mean;
+Output<ngraph::Node> var;
 std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)};
 double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector batch_norm(const Node& node);
+OutputVector batch_norm(const Node& node);
 } // namespace set_1

@@ -28,7 +28,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cast(const Node& node)
+OutputVector cast(const Node& node)
 {
 auto data = node.get_ng_inputs().at(0);
 int64_t target_type = node.get_attribute_value<int64_t>("to");

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cast(const Node& node);
+OutputVector cast(const Node& node);
 } // namespace set_1

@@ -30,7 +30,7 @@ namespace ngraph
 {
 namespace set_1
 {
-inline NodeVector ceil(const Node& node)
+inline OutputVector ceil(const Node& node)
 {
 return {std::make_shared<default_opset::Ceiling>(node.get_ng_inputs().at(0))};
 }

@@ -30,7 +30,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector clip(const Node& node)
+OutputVector clip(const Node& node)
 {
 const auto data = node.get_ng_inputs().at(0);
@@ -47,14 +47,14 @@ namespace ngraph
 namespace set_11
 {
-NodeVector clip(const Node& node)
+OutputVector clip(const Node& node)
 {
-const NodeVector inputs{node.get_ng_inputs()};
-const std::shared_ptr<ngraph::Node> data = inputs.at(0);
-const element::Type data_type = data->get_element_type();
-const Shape data_shape = data->get_shape();
-std::shared_ptr<ngraph::Node> min;
-std::shared_ptr<ngraph::Node> max;
+const OutputVector inputs{node.get_ng_inputs()};
+const Output<ngraph::Node> data = inputs.at(0);
+const element::Type data_type = data.get_element_type();
+const Shape data_shape = data.get_shape();
+Output<ngraph::Node> min;
+Output<ngraph::Node> max;
 // If second input is provided, assign to min input, otherwise set lowest
 // numeric limit of double as min input.

@@ -27,13 +27,13 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector clip(const Node& node);
+OutputVector clip(const Node& node);
 } // namespace set_1
 namespace set_11
 {
-NodeVector clip(const Node& node);
+OutputVector clip(const Node& node);
 } // namespace set_11

@@ -30,9 +30,9 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector concat(const Node& node)
+OutputVector concat(const Node& node)
 {
-NodeVector inputs{node.get_ng_inputs()};
+OutputVector inputs{node.get_ng_inputs()};
 std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
 return {std::make_shared<default_opset::Concat>(inputs, axis)};
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector concat(const Node& node);
+OutputVector concat(const Node& node);
 } // namespace set_1

@@ -153,7 +153,7 @@ namespace ngraph
 }
 }
-NodeVector constant(const onnx_import::Node& node)
+OutputVector constant(const onnx_import::Node& node)
 {
 return {make_constant(node.get_attribute_value<Tensor>("value"))};
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector constant(const Node& node);
+OutputVector constant(const Node& node);
 } // namespace set_1

@@ -28,9 +28,9 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector constant_of_shape(const onnx_import::Node& node)
+OutputVector constant_of_shape(const onnx_import::Node& node)
 {
-std::shared_ptr<ngraph::Node> constant_value;
+Output<ngraph::Node> constant_value;
 if (node.has_attribute("value"))
 {
 auto value_tensor = node.get_attribute_value<Tensor>("value");

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector constant_of_shape(const Node& node);
+OutputVector constant_of_shape(const Node& node);
 } // namespace set_1

@@ -38,8 +38,8 @@ namespace ngraph
 namespace
 {
 std::shared_ptr<ngraph::op::Op>
-make_ng_convolution(const std::shared_ptr<ngraph::Node>& data,
-const std::shared_ptr<ngraph::Node>& filters,
+make_ng_convolution(const Output<ngraph::Node>& data,
+const Output<ngraph::Node>& filters,
 const ngraph::Strides& strides,
 const ngraph::Strides& dilations,
 const ngraph::CoordinateDiff& padding_below,
@@ -49,7 +49,7 @@ namespace ngraph
 {
 if (groups > 1)
 {
-auto filters_shape = filters->get_shape();
+auto filters_shape = filters.get_shape();
 filters_shape.at(0) = filters_shape.at(0) / groups;
 filters_shape.insert(filters_shape.begin(), groups);
@@ -77,18 +77,16 @@ namespace ngraph
 }
 }
-std::shared_ptr<ngraph::Node>
-add_bias(const std::shared_ptr<ngraph::Node>& ng_conv,
-const std::shared_ptr<ngraph::Node>& bias)
+std::shared_ptr<ngraph::Node> add_bias(const Output<ngraph::Node>& ng_conv,
+const Output<ngraph::Node>& bias)
 {
-const auto rank_of_conv =
-ng_conv->get_output_partial_shape(0).rank().get_length();
+const auto rank_of_conv = ng_conv.get_partial_shape().rank().get_length();
 // reshape the bias node {M} to {1, M, 1, 1, ..., 1}
 // this is required by the addition operation that needs to be able
 // to broadcast the bias to match the shape of the convolution node
 std::vector<size_t> reshape_pattern_values(rank_of_conv, 1U);
-reshape_pattern_values[1] = bias->get_shape().front();
+reshape_pattern_values[1] = bias.get_shape().front();
 const auto reshape_pattern =
 default_opset::Constant::create(element::u64,
 Shape{reshape_pattern_values.size()},
@@ -101,16 +99,16 @@ namespace ngraph
 }
 } // namespace
-NodeVector conv(const Node& node)
+OutputVector conv(const Node& node)
 {
 // in the current implementation we assume that the data input rank is static
 // and only the 'batch' dimension can be dynamic
-const NodeVector& inputs = node.get_ng_inputs();
+const OutputVector& inputs = node.get_ng_inputs();
 const auto data = inputs.at(0);
 const auto filters = inputs.at(1);
 const auto groups = node.get_attribute_value<int64_t>("group", 1);
-NGRAPH_CHECK(data->get_output_partial_shape(0).rank().is_static(),
+NGRAPH_CHECK(data.get_partial_shape().rank().is_static(),
 "The input data tensor's rank has to be known (static)");
 const auto strides = convpool::get_strides(node);
@@ -137,7 +135,7 @@ namespace ngraph
 else
 {
 const auto bias = inputs.at(2);
-const auto bias_ps = bias->get_output_partial_shape(0);
+const auto bias_ps = bias.get_partial_shape();
 NGRAPH_CHECK(bias_ps.is_static() && is_vector(bias_ps.to_shape()),
 "The bias input needs to be a static 1D vector");

@@ -33,7 +33,7 @@ namespace ngraph
 ///
 /// \return The vector containing Ngraph nodes producing output of ONNX convolution
 /// operation.
-NodeVector conv(const Node& node);
+OutputVector conv(const Node& node);
 } // namespace set_1

@@ -14,6 +14,9 @@
 // limitations under the License.
 //*****************************************************************************
+// Disabled in CMakeList
+// Update to higher opset required
 #include "conv_integer.hpp"
 #include "exceptions.hpp"
 #include "ngraph/builder/make_constant.hpp"
@@ -31,9 +34,9 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector conv_integer(const Node& node)
+OutputVector conv_integer(const Node& node)
 {
-const NodeVector& inputs = node.get_ng_inputs();
+const OutputVector& inputs = node.get_ng_inputs();
 auto num_inputs = inputs.size();
 auto input = inputs.at(0);
 auto filters = inputs.at(1);
@@ -51,19 +54,18 @@ namespace ngraph
 ngraph::op::PadType auto_pad_type = convpool::get_auto_pad(node);
 auto& padding_below = paddings.first;
 auto& padding_above = paddings.second;
-convpool::calculate_auto_pads(input->get_shape(),
-filters->get_shape(),
+convpool::calculate_auto_pads(input.get_shape(),
+filters.get_shape(),
 window_movement_strides,
 window_dilation_strides,
 auto_pad_type,
 padding_below,
 padding_above);
-const Strides default_data_dilation_strides(input->get_shape().size() - 2, 1);
+const Strides default_data_dilation_strides(input.get_shape().size() - 2, 1);
 auto scale_one = make_constant(ngraph::element::f32, Shape{}, 1);
-auto input_zero_point = make_constant(input->get_element_type(), Shape{}, 0);
-auto filters_zero_point =
-make_constant(filters->get_element_type(), Shape{}, 0);
+auto input_zero_point = make_constant(input.get_element_type(), Shape{}, 0);
+auto filters_zero_point = make_constant(filters.get_element_type(), Shape{}, 0);
 auto output_zero_point = make_constant(ngraph::element::i32, Shape{}, 0);
 if (num_inputs == 2)

@@ -14,6 +14,9 @@
 // limitations under the License.
 //*****************************************************************************
+// Disabled in CMakeList
+// Update to higher opset required
 #pragma once
 #include "core/node.hpp"
@@ -33,7 +36,7 @@ namespace ngraph
 ///
 /// \return The vector containing Ngraph nodes producing output of quantized ONNX
 /// convolution operation.
-NodeVector conv_integer(const Node& node);
+OutputVector conv_integer(const Node& node);
 } // namespace set_1

@@ -44,9 +44,9 @@ namespace ngraph
 {
 namespace
 {
-std::shared_ptr<ngraph::Node>
-make_group_conv_backprop(const std::shared_ptr<ngraph::Node>& data,
-const std::shared_ptr<ngraph::Node>& filters,
+Output<ngraph::Node>
+make_group_conv_backprop(const Output<ngraph::Node>& data,
+const Output<ngraph::Node>& filters,
 const Strides& strides,
 const Strides& dilations,
 const CoordinateDiff& pads_begin,
@@ -83,9 +83,9 @@ namespace ngraph
 }
 }
-std::shared_ptr<ngraph::Node>
-make_conv_backprop(const std::shared_ptr<ngraph::Node>& data,
-const std::shared_ptr<ngraph::Node>& filters,
+Output<ngraph::Node>
+make_conv_backprop(const Output<ngraph::Node>& data,
+const Output<ngraph::Node>& filters,
 const Strides& strides,
 const Strides& dilations,
 const CoordinateDiff& pads_begin,
@@ -124,10 +124,9 @@ namespace ngraph
 }
 }
-std::shared_ptr<ngraph::Node>
-get_reshaped_filters(const std::shared_ptr<ngraph::Node>& filters,
-const PartialShape& filters_pshape,
-int64_t groups)
+Output<ngraph::Node> get_reshaped_filters(const Output<ngraph::Node>& filters,
+const PartialShape& filters_pshape,
+int64_t groups)
 {
 if (filters_pshape.is_static())
 {
@@ -180,12 +179,11 @@ namespace ngraph
 }
 }
-std::shared_ptr<ngraph::Node>
-get_prepared_bias(const std::shared_ptr<ngraph::Node>& bias,
-const std::shared_ptr<ngraph::Node>& conv)
+Output<ngraph::Node> get_prepared_bias(const Output<ngraph::Node>& bias,
+const Output<ngraph::Node>& conv)
 {
 // Prepare bias shape [1, C, 1, 1]
-const auto& conv_pshape = conv->get_output_partial_shape(0);
+const auto& conv_pshape = conv.get_partial_shape();
 std::shared_ptr<ngraph::Node> bias_shape_node;
 if (conv_pshape.rank().is_static() && conv_pshape[1].is_static())
@@ -231,9 +229,9 @@ namespace ngraph
 }
 }
-NodeVector conv_transpose(const Node& node)
+OutputVector conv_transpose(const Node& node)
 {
-const NodeVector& inputs = node.get_ng_inputs();
+const OutputVector& inputs = node.get_ng_inputs();
 CHECK_VALID_NODE(node,
 inputs.size() == 2 || inputs.size() == 3,
@@ -243,8 +241,8 @@ namespace ngraph
 auto data = inputs[0];
 auto filters = inputs[1];
-const auto& data_pshape = data->get_output_partial_shape(0);
-const auto& filters_pshape = filters->get_output_partial_shape(0);
+const auto& data_pshape = data.get_partial_shape();
+const auto& filters_pshape = filters.get_partial_shape();
 std::size_t num_spatial_dims = 0;
 Strides strides, dilations;
@@ -291,7 +289,7 @@ namespace ngraph
 CHECK_VALID_NODE(
 node, groups >= 0, "Incorrect value of 'group' attribute: ", groups);
-std::shared_ptr<ngraph::Node> conv_node;
+Output<ngraph::Node> conv_node;
 // reshape filters to match desired shape:
 // [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1]

@@ -33,7 +33,7 @@ namespace ngraph
 ///
 /// \return The vector containing Ngraph nodes producing output of ONNX convolution
 /// operation.
-NodeVector conv_transpose(const Node& node);
+OutputVector conv_transpose(const Node& node);
 } // namespace set_1

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cos(const Node& node)
+OutputVector cos(const Node& node)
 {
 return {std::make_shared<default_opset::Cos>(node.get_ng_inputs().at(0))};
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cos(const Node& node);
+OutputVector cos(const Node& node);
 }
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cosh(const Node& node)
+OutputVector cosh(const Node& node)
 {
 return {std::make_shared<default_opset::Cosh>(node.get_ng_inputs().at(0))};
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cosh(const Node& node);
+OutputVector cosh(const Node& node);
 }
 }
 } // namespace onnx_import

@@ -27,13 +27,13 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cum_sum(const Node& node)
+OutputVector cum_sum(const Node& node)
 {
 auto inputs = node.get_ng_inputs();
 auto data = inputs.at(0);
 bool exclusive = node.get_attribute_value<std::int64_t>("exclusive", 0);
 bool reverse = node.get_attribute_value<std::int64_t>("reverse", 0);
-std::shared_ptr<ngraph::Node> axis;
+Output<ngraph::Node> axis;
 if (inputs.size() > 1)
 {
@@ -44,7 +44,7 @@ namespace ngraph
 axis =
 default_opset::Constant::create(element::i64, Shape{}, {0}); // default
 }
-return NodeVector{
+return OutputVector{
 std::make_shared<default_opset::CumSum>(data, axis, exclusive, reverse)};
 }

@@ -27,7 +27,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector cum_sum(const Node& node);
+OutputVector cum_sum(const Node& node);
 } // namespace set_1

@@ -25,7 +25,7 @@ namespace ngraph
 {
 namespace set_1
 {
-NodeVector depth_to_space(const Node& node)
+OutputVector depth_to_space(const Node& node)
 {
 auto data = node.get_ng_inputs().at(0);
 const auto mode = node.get_attribute_value<std::string>("mode", "DCR");
@@ -34,7 +34,7 @@ namespace ngraph
 ? default_opset::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST
 : default_opset::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST;
 const auto block_size = node.get_attribute_value<std::int64_t>("blocksize");
-return NodeVector{std::make_shared<default_opset::DepthToSpace>(
+return OutputVector{std::make_shared<default_opset::DepthToSpace>(
 data, ngraph_mode, block_size)};
 }
 } // namespace set_1

@@ -34,9 +34,9 @@ namespace ngraph
 ///
 /// \param[in] node The ONNX input node describing operation.
 ///
-/// \return NodeVector containing Tensor with shape:
+/// \return OutputVector containing Tensor with shape:
 /// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
-NodeVector depth_to_space(const Node& node);
+OutputVector depth_to_space(const Node& node);
 } // namespace set_1
 } // namespace op

@@ -36,13 +36,13 @@ namespace ngraph
 {
 namespace
 {
-std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
+Output<ngraph::Node> get_zero_point(const OutputVector& inputs)
 {
 if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2]))
 {
 auto zero_point = inputs[2];
-if (zero_point->get_element_type() != element::f32)
+if (zero_point.get_element_type() != element::f32)
 {
 zero_point =
 std::make_shared<default_opset::Convert>(zero_point, element::f32);
@@ -58,9 +58,9 @@ namespace ngraph
 }
 namespace set_1
 {
-NodeVector dequantize_linear(const Node& node)
+OutputVector dequantize_linear(const Node& node)
 {
-const NodeVector inputs{node.get_ng_inputs()};
+const OutputVector inputs{node.get_ng_inputs()};
 NGRAPH_CHECK(
 2 <= inputs.size() && inputs.size() <= 3,
@@ -71,8 +71,9 @@ namespace ngraph
 const auto scale = inputs[1];
 const auto zero_point = get_zero_point(inputs);
-common::validate_scalar_input("Dequantization scale", scale, {element::f32});
-common::validate_scalar_input("Zero point", zero_point);
+common::validate_scalar_input(
+"Dequantization scale", scale.get_node_shared_ptr(), {element::f32});
+common::validate_scalar_input("Zero point", zero_point.get_node_shared_ptr());
 const auto converted_x =
 std::make_shared<default_opset::Convert>(x, element::f32);
@@ -86,11 +87,11 @@ namespace ngraph
 {
 namespace
 {
-void validate_scale(const std::shared_ptr<ngraph::Node> scale,
-const std::shared_ptr<ngraph::Node> x,
+void validate_scale(const Output<ngraph::Node> scale,
+const Output<ngraph::Node> x,
 const int64_t axis)
 {
-const auto& scale_shape = scale->get_output_partial_shape(0);
+const auto& scale_shape = scale.get_partial_shape();
 NGRAPH_CHECK(scale_shape.rank().get_length() == 0 ||
 scale_shape.rank().get_length() == 1,
 "Dequantization scale needs to be a scalar or a vector.");
@@ -98,7 +99,7 @@ namespace ngraph
 if (scale_shape.rank().get_length() == 1)
 {
 const auto& scale_dim = scale_shape[0];
-const auto& x_shape = x->get_output_partial_shape(0);
+const auto& x_shape = x.get_partial_shape();
 const auto& x_dim_at_axis = x_shape[axis];
 NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis),
@@ -111,11 +112,11 @@ namespace ngraph
 }
 }
-void validate_zero_point(const std::shared_ptr<ngraph::Node> zero_point,
-const std::shared_ptr<ngraph::Node> x,
+void validate_zero_point(const Output<ngraph::Node> zero_point,
+const Output<ngraph::Node> x,
 const int64_t axis)
 {
-const auto& zero_point_shape = zero_point->get_output_partial_shape(0);
+const auto& zero_point_shape = zero_point.get_partial_shape();
 NGRAPH_CHECK(zero_point_shape.rank().get_length() == 0 ||
 zero_point_shape.rank().get_length() == 1,
 "Zero point needs to be a scalar or a vector.");
@@ -123,7 +124,7 @@ namespace ngraph
 if (zero_point_shape.rank().get_length() == 1)
 {
 const auto& zero_point_dim = zero_point_shape[0];
-const auto& x_shape = x->get_output_partial_shape(0);
+const auto& x_shape = x.get_partial_shape();
 const auto& x_dim_at_axis = x_shape[axis];
 NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis),
@@ -136,10 +137,9 @@ namespace ngraph
 }
 }
-std::shared_ptr<ngraph::Node>
-reshape_input(const std::shared_ptr<ngraph::Node> input,
-const int64_t axis,
-const PartialShape& x_shape)
+std::shared_ptr<ngraph::Node> reshape_input(const Output<ngraph::Node> input,
+const int64_t axis,
+const PartialShape& x_shape)
 {
 std::vector<int64_t> target_dims;
@@ -170,9 +170,9 @@ namespace ngraph
 }
 }
-NodeVector dequantize_linear(const Node& node)
+OutputVector dequantize_linear(const Node& node)
 {
-const NodeVector inputs{node.get_ng_inputs()};
+const OutputVector inputs{node.get_ng_inputs()};
 NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3,
 "The DequantizeLinear op expects 2 required and one optional "
@@ -183,7 +183,7 @@ namespace ngraph
 auto scale = inputs[1];
 auto zero_point = get_zero_point(inputs);
-const auto x_shape = x->get_output_partial_shape(0);
+const auto x_shape = x.get_partial_shape();
 NGRAPH_CHECK(x_shape.rank().is_static(),
 "Rank of the input data tensor has to be known (static).");

View File

@ -27,13 +27,13 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector dequantize_linear(const Node& node); OutputVector dequantize_linear(const Node& node);
} // namespace set_1 } // namespace set_1
namespace set_13 namespace set_13
{ {
NodeVector dequantize_linear(const Node& node); OutputVector dequantize_linear(const Node& node);
} }
} // namespace op } // namespace op

View File

@ -32,7 +32,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector div(const Node& node) inline OutputVector div(const Node& node)
{ {
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0); const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1); Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
@ -50,7 +50,7 @@ namespace ngraph
namespace set_7 namespace set_7
{ {
inline NodeVector div(const Node& node) inline OutputVector div(const Node& node)
{ {
return {std::make_shared<default_opset::Divide>(node.get_ng_inputs().at(0), return {std::make_shared<default_opset::Divide>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))}; node.get_ng_inputs().at(1))};

View File

@ -30,11 +30,12 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector dropout(const Node& node) inline OutputVector dropout(const Node& node)
{ {
// First value is actual output of Dropout, // First value is actual output of Dropout,
// the second one is just a placeholder for optional trailing output. // the second one is just a placeholder for optional trailing output.
return {node.get_ng_inputs().at(0), std::make_shared<NullNode>()}; return {node.get_ng_inputs().at(0).get_node_shared_ptr(),
std::make_shared<NullNode>()};
} }
} // namespace set_1 } // namespace set_1
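Where a std::shared_ptr<ngraph::Node> is still required, as with the get_node_shared_ptr() call in the Dropout hunk above, the wrapper is unwrapped explicitly; in the other direction a single-output node's shared_ptr converts implicitly, which is what lets braced OutputVector returns accept freshly created nodes. A hedged sketch with made-up names:

#include <memory>
#include "ngraph/ngraph.hpp"

void rename_producer(const ngraph::Output<ngraph::Node>& value)
{
    // Unwrap to the producing node when a shared_ptr-based API is needed.
    std::shared_ptr<ngraph::Node> producer = value.get_node_shared_ptr();
    producer->set_friendly_name("renamed");

    // The reverse direction is implicit: the shared_ptr becomes an
    // Output<Node> entry of the vector again.
    ngraph::OutputVector results{producer};
    (void)results;
}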

View File

@ -28,12 +28,12 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector elu(const Node& node) OutputVector elu(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1); double alpha = node.get_attribute_value<double>("alpha", 1);
return NodeVector{std::make_shared<default_opset::Elu>(data, alpha)}; return OutputVector{std::make_shared<default_opset::Elu>(data, alpha)};
} }
} // namespace set_1 } // namespace set_1
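Every importer entry point now has the same shape: it takes an onnx_import::Node and returns an OutputVector. A hypothetical example, shaped like the set_1 functions in this diff; "scaled_identity" is not a real ONNX operator and only illustrates the new signature, with includes following the importer's own style:

#include "core/node.hpp"
#include "default_opset.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                OutputVector scaled_identity(const Node& node)
                {
                    // get_ng_inputs() now yields Output<ngraph::Node> entries.
                    const auto data = node.get_ng_inputs().at(0);
                    const auto alpha = node.get_attribute_value<double>("alpha", 1.0);
                    const auto alpha_node = default_opset::Constant::create(
                        data.get_element_type(), Shape{}, {alpha});
                    return {std::make_shared<default_opset::Multiply>(data, alpha_node)};
                }
            } // namespace set_1
        }     // namespace op
    }         // namespace onnx_import
} // namespace ngraph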

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector elu(const Node& node); OutputVector elu(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -30,7 +30,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector equal(const Node& node) inline OutputVector equal(const Node& node)
{ {
return {std::make_shared<default_opset::Equal>(node.get_ng_inputs().at(0), return {std::make_shared<default_opset::Equal>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))}; node.get_ng_inputs().at(1))};

View File

@ -30,7 +30,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector erf(const Node& node) inline OutputVector erf(const Node& node)
{ {
return {std::make_shared<default_opset::Erf>(node.get_ng_inputs().at(0))}; return {std::make_shared<default_opset::Erf>(node.get_ng_inputs().at(0))};
} }

View File

@ -30,7 +30,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector exp(const Node& node) inline OutputVector exp(const Node& node)
{ {
return {std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0))}; return {std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0))};
} }

View File

@ -30,10 +30,10 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector expand(const Node& node) OutputVector expand(const Node& node)
{ {
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)}; const Output<ngraph::Node> data{node.get_ng_inputs().at(0)};
const std::shared_ptr<ngraph::Node> shape{node.get_ng_inputs().at(1)}; const Output<ngraph::Node> shape{node.get_ng_inputs().at(1)};
return {std::make_shared<default_opset::Broadcast>( return {std::make_shared<default_opset::Broadcast>(
data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)}; data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)};

View File

@ -29,7 +29,7 @@ namespace ngraph
// Expand operator has been available since version 8 of the default ONNX operator set. // Expand operator has been available since version 8 of the default ONNX operator set.
// Currently, Expand is assigned to version 1 due to temporary reason. // Currently, Expand is assigned to version 1 due to temporary reason.
{ {
NodeVector expand(const Node& node); OutputVector expand(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -28,10 +28,10 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector eye_like(const Node& node) OutputVector eye_like(const Node& node)
{ {
const auto input = node.get_ng_inputs().at(0); const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input->get_shape(); const auto& input_shape = input.get_shape();
std::int64_t dtype; std::int64_t dtype;
element::Type target_type; element::Type target_type;
@ -44,7 +44,7 @@ namespace ngraph
} }
else else
{ {
target_type = input->get_element_type(); target_type = input.get_element_type();
} }
CHECK_VALID_NODE(node, CHECK_VALID_NODE(node,

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector eye_like(const Node& node); OutputVector eye_like(const Node& node);
} // namespace set_1 } // namespace set_1
} // namespace op } // namespace op

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector fake_quantize(const onnx_import::Node& node) OutputVector fake_quantize(const onnx_import::Node& node)
{ {
const auto inputs = node.get_ng_inputs(); const auto inputs = node.get_ng_inputs();
const auto X = inputs.at(0); const auto X = inputs.at(0);

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector fake_quantize(const Node& node); OutputVector fake_quantize(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -29,12 +29,12 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector flatten(const Node& node) OutputVector flatten(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; OutputVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0); auto data = inputs.at(0);
auto axis = node.get_attribute_value<std::int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
const auto data_rank = data->get_output_partial_shape(0).rank(); const auto data_rank = data.get_partial_shape().rank();
if (data_rank.is_static()) if (data_rank.is_static())
{ {

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector flatten(const Node& node); OutputVector flatten(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -31,7 +31,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector floor(const Node& node) inline OutputVector floor(const Node& node)
{ {
return {std::make_shared<default_opset::Floor>(node.get_ng_inputs().at(0))}; return {std::make_shared<default_opset::Floor>(node.get_ng_inputs().at(0))};
} }

View File

@ -31,14 +31,14 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector gather(const Node& node) inline OutputVector gather(const Node& node)
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; OutputVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0); auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1); auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0); auto axis = node.get_attribute_value<int64_t>("axis", 0);
const auto valid_axis = ngraph::normalize_axis( const auto valid_axis = ngraph::normalize_axis(
node.get_description(), axis, data->get_output_partial_shape(0).rank()); node.get_description(), axis, data.get_partial_shape().rank());
return {std::make_shared<default_opset::Gather>( return {std::make_shared<default_opset::Gather>(
data, data,

View File

@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// Disabled in CMakeList
// Update to higher opset required
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
@ -25,9 +28,9 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector gather_nd(const Node& node) OutputVector gather_nd(const Node& node)
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; OutputVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0); auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1); auto indices = ng_inputs.at(1);

View File

@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// Disabled in CMakeList
// Update to higher opset required
#pragma once #pragma once
#include "core/node.hpp" #include "core/node.hpp"
@ -27,7 +30,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector gather_nd(const Node& node); OutputVector gather_nd(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -32,12 +32,12 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector gemm(const Node& node) OutputVector gemm(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; OutputVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> input_a = inputs.at(0); Output<ngraph::Node> input_a = inputs.at(0);
std::shared_ptr<ngraph::Node> input_b = inputs.at(1); Output<ngraph::Node> input_b = inputs.at(1);
std::shared_ptr<ngraph::Node> input_c; Output<ngraph::Node> input_c;
if (inputs.size() == 3) if (inputs.size() == 3)
{ {
@ -46,16 +46,16 @@ namespace ngraph
else else
{ {
input_c = default_opset::Constant::create( input_c = default_opset::Constant::create(
input_b->get_element_type(), ngraph::Shape{}, {0}); input_b.get_element_type(), ngraph::Shape{}, {0});
} }
const auto alpha = node.get_attribute_value<float>("alpha", 1); const auto alpha = node.get_attribute_value<float>("alpha", 1);
const auto beta = node.get_attribute_value<float>("beta", 1); const auto beta = node.get_attribute_value<float>("beta", 1);
const auto alpha_node = default_opset::Constant::create( const auto alpha_node = default_opset::Constant::create(
input_b->get_element_type(), Shape{}, {alpha}); input_b.get_element_type(), Shape{}, {alpha});
const auto beta_node = default_opset::Constant::create( const auto beta_node = default_opset::Constant::create(
input_c->get_element_type(), Shape{}, {beta}); input_c.get_element_type(), Shape{}, {beta});
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0); const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0); const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
@ -85,7 +85,7 @@ namespace ngraph
auto beta_times_input_c = auto beta_times_input_c =
std::make_shared<default_opset::Multiply>(beta_node, input_c); std::make_shared<default_opset::Multiply>(beta_node, input_c);
return NodeVector{ return OutputVector{
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)}; std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
} }
@ -93,12 +93,12 @@ namespace ngraph
namespace set_6 namespace set_6
{ {
NodeVector gemm(const Node& node) OutputVector gemm(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; OutputVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> input_a = inputs.at(0); Output<ngraph::Node> input_a = inputs.at(0);
std::shared_ptr<ngraph::Node> input_b = inputs.at(1); Output<ngraph::Node> input_b = inputs.at(1);
std::shared_ptr<ngraph::Node> input_c; Output<ngraph::Node> input_c;
if (inputs.size() == 3) if (inputs.size() == 3)
{ {
@ -107,16 +107,16 @@ namespace ngraph
else else
{ {
input_c = default_opset::Constant::create( input_c = default_opset::Constant::create(
input_b->get_element_type(), ngraph::Shape{}, {0}); input_b.get_element_type(), ngraph::Shape{}, {0});
} }
const auto alpha = node.get_attribute_value<float>("alpha", 1); const auto alpha = node.get_attribute_value<float>("alpha", 1);
const auto beta = node.get_attribute_value<float>("beta", 1); const auto beta = node.get_attribute_value<float>("beta", 1);
const auto alpha_node = default_opset::Constant::create( const auto alpha_node = default_opset::Constant::create(
input_b->get_element_type(), Shape{}, {alpha}); input_b.get_element_type(), Shape{}, {alpha});
const auto beta_node = default_opset::Constant::create( const auto beta_node = default_opset::Constant::create(
input_c->get_element_type(), Shape{}, {beta}); input_c.get_element_type(), Shape{}, {beta});
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0); const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0); const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
@ -133,7 +133,7 @@ namespace ngraph
auto beta_times_input_c = auto beta_times_input_c =
std::make_shared<default_opset::Multiply>(beta_node, input_c); std::make_shared<default_opset::Multiply>(beta_node, input_c);
return NodeVector{ return OutputVector{
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)}; std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
} }
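In the Gemm hunks an Output<ngraph::Node> variable holds either a graph input or a freshly created Constant, and is then passed straight to op constructors, which accept Output<Node> arguments. A small sketch under the assumption that opset1 ops are in use; the function name is illustrative only:

#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"

ngraph::Output<ngraph::Node> add_scalar_bias(const ngraph::Output<ngraph::Node>& value)
{
    // A shared_ptr<Constant> converts to Output<Node>, so the same variable
    // can carry either a real input or a synthesized default.
    ngraph::Output<ngraph::Node> bias = ngraph::opset1::Constant::create(
        value.get_element_type(), ngraph::Shape{}, {0});

    // Binary op constructors take Output<Node> directly; the returned
    // shared_ptr<Add> converts back to Output<Node> for the caller.
    return std::make_shared<ngraph::opset1::Add>(value, bias);
}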

View File

@ -27,13 +27,13 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector gemm(const Node& node); OutputVector gemm(const Node& node);
} // namespace set_1 } // namespace set_1
namespace set_6 namespace set_6
{ {
NodeVector gemm(const Node& node); OutputVector gemm(const Node& node);
} // namespace set_6 } // namespace set_6

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector global_average_pool(const Node& node) OutputVector global_average_pool(const Node& node)
{ {
return pooling::GlobalPoolingFactory(node).make_avg_pool(); return pooling::GlobalPoolingFactory(node).make_avg_pool();
} }

View File

@ -33,7 +33,7 @@ namespace ngraph
/// ///
/// \return The vector containing Ngraph nodes producing output of ONNX /// \return The vector containing Ngraph nodes producing output of ONNX
/// GlobalAveragePool operation. /// GlobalAveragePool operation.
NodeVector global_average_pool(const Node& node); OutputVector global_average_pool(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector global_max_pool(const Node& node) OutputVector global_max_pool(const Node& node)
{ {
return pooling::GlobalPoolingFactory(node).make_max_pool(); return pooling::GlobalPoolingFactory(node).make_max_pool();
} }

View File

@ -33,7 +33,7 @@ namespace ngraph
/// ///
/// \return The vector containing Ngraph nodes producing output of ONNX /// \return The vector containing Ngraph nodes producing output of ONNX
/// GlobalMaxPool operation. /// GlobalMaxPool operation.
NodeVector global_max_pool(const Node& node); OutputVector global_max_pool(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -31,7 +31,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector greater(const Node& node) inline OutputVector greater(const Node& node)
{ {
return {std::make_shared<default_opset::Greater>(node.get_ng_inputs().at(0), return {std::make_shared<default_opset::Greater>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))}; node.get_ng_inputs().at(1))};

View File

@ -46,7 +46,7 @@ namespace ngraph
if (linear_before_reset) if (linear_before_reset)
{ {
const auto& ng_inputs = node.get_ng_inputs(); const auto& ng_inputs = node.get_ng_inputs();
const auto el_type = ng_inputs.at(0)->get_output_element_type(0); const auto el_type = ng_inputs.at(0).get_element_type();
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3))) if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{ {
@ -68,18 +68,18 @@ namespace ngraph
// ] // ]
m_map[recurrent::OpInput::B] = m_map[recurrent::OpInput::B] =
std::make_shared<default_opset::Concat>( std::make_shared<default_opset::Concat>(
NodeVector{wr_z_bias, OutputVector{wr_z_bias,
wr_r_bias, wr_r_bias,
split_bias.at(2), split_bias.at(2),
split_bias.at(5)}, split_bias.at(5)},
1); 1);
} }
else else
{ {
const std::size_t hidden_size = const std::size_t hidden_size =
m_map[recurrent::OpInput::R]->get_shape().back(); m_map[recurrent::OpInput::R].get_shape().back();
const std::size_t num_directions = const std::size_t num_directions =
m_map[recurrent::OpInput::W]->get_shape().front(); m_map[recurrent::OpInput::W].get_shape().front();
m_map[recurrent::OpInput::B] = m_map[recurrent::OpInput::B] =
std::make_shared<default_opset::Constant>( std::make_shared<default_opset::Constant>(
@ -110,7 +110,7 @@ namespace ngraph
}; };
} }
NodeVector gru(const Node& node) OutputVector gru(const Node& node)
{ {
constexpr std::size_t gates_count = 3; constexpr std::size_t gates_count = 3;
GRUInputMap input_map{node, gates_count}; GRUInputMap input_map{node, gates_count};
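Variadic ops take the new vector type as well: the GRU bias handling above now builds an OutputVector and hands it to Concat. A minimal sketch, again assuming opset1 and using placeholder names:

#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"

std::shared_ptr<ngraph::Node> concat_pair(const ngraph::Output<ngraph::Node>& a,
                                          const ngraph::Output<ngraph::Node>& b)
{
    // OutputVector is std::vector<Output<Node>>; Concat accepts it directly
    // where a NodeVector used to be required.
    const ngraph::OutputVector pieces{a, b};
    return std::make_shared<ngraph::opset1::Concat>(pieces, 1 /*axis*/);
}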

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector gru(const Node& node); OutputVector gru(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -27,17 +27,17 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector hard_sigmoid(const Node& node) OutputVector hard_sigmoid(const Node& node)
{ {
const auto data = node.get_ng_inputs().at(0); const auto data = node.get_ng_inputs().at(0);
const auto alpha = default_opset::Constant::create<double>( const auto alpha = default_opset::Constant::create<double>(
data->get_element_type(), data.get_element_type(),
Shape{}, Shape{},
std::vector<double>{node.get_attribute_value<double>("alpha", 0.2)}); std::vector<double>{node.get_attribute_value<double>("alpha", 0.2)});
const auto beta = default_opset::Constant::create<double>( const auto beta = default_opset::Constant::create<double>(
data->get_element_type(), data.get_element_type(),
Shape{}, Shape{},
std::vector<double>{node.get_attribute_value<double>("beta", 0.5)}); std::vector<double>{node.get_attribute_value<double>("beta", 0.5)});

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector hard_sigmoid(const Node& node); OutputVector hard_sigmoid(const Node& node);
} // namespace set_1 } // namespace set_1

View File

@ -31,10 +31,10 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector hardmax(const Node& node) OutputVector hardmax(const Node& node)
{ {
const auto input = node.get_ng_inputs().at(0); const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input->get_output_partial_shape(0); const auto& input_shape = input.get_partial_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
if (input_shape.rank().is_static()) if (input_shape.rank().is_static())
@ -48,11 +48,10 @@ namespace ngraph
const auto coerced_tensor_shape = const auto coerced_tensor_shape =
std::make_shared<default_opset::ShapeOf>(coerced_tensor); std::make_shared<default_opset::ShapeOf>(coerced_tensor);
std::shared_ptr<ngraph::Node> row_size = Output<ngraph::Node> row_size = std::make_shared<default_opset::Gather>(
std::make_shared<default_opset::Gather>( coerced_tensor_shape,
coerced_tensor_shape, default_opset::Constant::create(element::i64, {1}, {1}),
default_opset::Constant::create(element::i64, {1}, {1}), default_opset::Constant::create(element::i64, {}, {0}));
default_opset::Constant::create(element::i64, {}, {0}));
row_size = ngraph::onnx_import::reshape::interpret_as_scalar(row_size); row_size = ngraph::onnx_import::reshape::interpret_as_scalar(row_size);
const auto indices_axis = 1; const auto indices_axis = 1;
@ -70,8 +69,8 @@ namespace ngraph
const auto results = std::make_shared<default_opset::OneHot>( const auto results = std::make_shared<default_opset::OneHot>(
topk->output(1), row_size, on_value, off_value, indices_axis); topk->output(1), row_size, on_value, off_value, indices_axis);
const auto converted_results = std::make_shared<default_opset::Convert>( const auto converted_results =
results, input->get_element_type()); std::make_shared<default_opset::Convert>(results, input.get_element_type());
if (input_shape.is_static()) if (input_shape.is_static())
{ {
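The Hardmax hunk also shows where Output<Node> values usually come from: a multi-result node such as TopK exposes each result through output(i), and the wrapper keeps both the producing node and the result index. A brief sketch with illustrative names:

#include <memory>
#include "ngraph/ngraph.hpp"

void inspect_result(const std::shared_ptr<ngraph::Node>& producer)
{
    // output(i) selects one concrete result, e.g. topk->output(1) above.
    ngraph::Output<ngraph::Node> first = producer->output(0);

    // The wrapper records both the node and which of its results is meant.
    (void)first.get_node();
    (void)first.get_index();
}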

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector hardmax(const Node& node); OutputVector hardmax(const Node& node);
} // namespace set_1 } // namespace set_1
} // namespace op } // namespace op

View File

@ -30,17 +30,17 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector identity(const Node& node) inline OutputVector identity(const Node& node)
{ {
auto input = node.get_ng_inputs().at(0); auto input = node.get_ng_inputs().at(0);
if (input->get_element_type() == ngraph::element::boolean) if (input.get_element_type() == ngraph::element::boolean)
{ {
const auto logic_zero = const auto logic_zero =
default_opset::Constant::create(ngraph::element::boolean, {}, {false}); default_opset::Constant::create(ngraph::element::boolean, {}, {false});
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)}; return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
} }
const auto zero = const auto zero =
default_opset::Constant::create(input->get_element_type(), {}, {0}); default_opset::Constant::create(input.get_element_type(), {}, {0});
return {std::make_shared<default_opset::Add>(input, zero)}; return {std::make_shared<default_opset::Add>(input, zero)};
} }
} // namespace set_1 } // namespace set_1

View File

@ -25,14 +25,14 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector image_scaler(const Node& node) OutputVector image_scaler(const Node& node)
{ {
const auto inputs = node.get_ng_inputs(); const auto inputs = node.get_ng_inputs();
NGRAPH_CHECK( NGRAPH_CHECK(
inputs.size() == 1, "ImageScaler 1 input tensor. Got: ", inputs.size()); inputs.size() == 1, "ImageScaler 1 input tensor. Got: ", inputs.size());
const auto data = inputs[0]; const auto data = inputs[0];
const auto& data_shape = data->get_output_partial_shape(0); const auto& data_shape = data.get_partial_shape();
NGRAPH_CHECK(data_shape.rank().same_scheme({4}), NGRAPH_CHECK(data_shape.rank().same_scheme({4}),
"ImageScaler expects a 4D tensor with NCHW format. Got: ", "ImageScaler expects a 4D tensor with NCHW format. Got: ",
data_shape); data_shape);

View File

@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector image_scaler(const Node& node); OutputVector image_scaler(const Node& node);
} }
} }
} }

View File

@ -39,7 +39,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector instance_norm(const Node& node) OutputVector instance_norm(const Node& node)
{ {
Output<ngraph::Node> data(node.get_ng_inputs().at(0)); Output<ngraph::Node> data(node.get_ng_inputs().at(0));
Output<ngraph::Node> scale(node.get_ng_inputs().at(1)); Output<ngraph::Node> scale(node.get_ng_inputs().at(1));

Some files were not shown because too many files have changed in this diff Show More