NodeVector -> OutputVector replacement (#1272)
This commit is contained in:
parent
dec7df17ed
commit
f34511642a
@ -23,8 +23,8 @@ void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto last_node = batch_to_space->decompose_op()[0];
|
auto last_node = batch_to_space->decompose_op()[0];
|
||||||
last_node->set_friendly_name(batch_to_space->get_friendly_name());
|
last_node.get_node()->set_friendly_name(batch_to_space->get_friendly_name());
|
||||||
ngraph::replace_node(batch_to_space, last_node);
|
ngraph::replace_node(batch_to_space, last_node.get_node_shared_ptr());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -23,8 +23,8 @@ void ngraph::pass::ConvertSpaceToBatch::convert_space_to_batch() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto last_node = space_to_batch->decompose_op()[0];
|
auto last_node = space_to_batch->decompose_op()[0];
|
||||||
last_node->set_friendly_name(space_to_batch->get_friendly_name());
|
last_node.get_node()->set_friendly_name(space_to_batch->get_friendly_name());
|
||||||
ngraph::replace_node(space_to_batch, last_node);
|
ngraph::replace_node(space_to_batch, last_node.get_node_shared_ptr());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ Output<Node> builder::MatmulFactory::get_right()
|
|||||||
return m_inputs.at(1);
|
return m_inputs.at(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector builder::MatmulFactory::make_matmul_op()
|
OutputVector builder::MatmulFactory::make_matmul_op()
|
||||||
{
|
{
|
||||||
auto left = get_left();
|
auto left = get_left();
|
||||||
auto right = get_right();
|
auto right = get_right();
|
||||||
|
@ -40,8 +40,8 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Create a sub-graph representing an ONNX MatMul operation.
|
/// \brief Create a sub-graph representing an ONNX MatMul operation.
|
||||||
///
|
///
|
||||||
/// \return NodeVector containing the sub-graph output node.
|
/// \return OutputVector containing the sub-graph output node.
|
||||||
virtual NodeVector make_matmul_op();
|
virtual OutputVector make_matmul_op();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// \return Output representing the left operand.
|
/// \return Output representing the left operand.
|
||||||
|
@ -172,8 +172,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void check_concat(const OutputVector& args,
|
||||||
check_concat(const NodeVector& args, const NodeVector& mins, const NodeVector& maxs)
|
const OutputVector& mins,
|
||||||
|
const OutputVector& maxs)
|
||||||
{
|
{
|
||||||
auto size = args.size();
|
auto size = args.size();
|
||||||
if (size != mins.size() || size != maxs.size())
|
if (size != mins.size() || size != maxs.size())
|
||||||
@ -184,17 +185,17 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
auto min = mins[i];
|
auto min = mins[i];
|
||||||
auto max = maxs[i];
|
auto max = maxs[i];
|
||||||
auto type = min->get_element_type();
|
auto type = min.get_element_type();
|
||||||
if (type != max->get_element_type())
|
if (type != max.get_element_type())
|
||||||
{
|
{
|
||||||
throw ngraph_error("check_concat: min and max must have same type");
|
throw ngraph_error("check_concat: min and max must have same type");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (min->get_shape() != Shape{1} || max->get_shape() != Shape{1})
|
if (min.get_shape() != Shape{1} || max.get_shape() != Shape{1})
|
||||||
{
|
{
|
||||||
throw ngraph_error("check_concat: min/max shape not Shape{1}: " +
|
throw ngraph_error("check_concat: min/max shape not Shape{1}: " +
|
||||||
vector_to_string(min->get_shape()) +
|
vector_to_string(min.get_shape()) +
|
||||||
vector_to_string(max->get_shape()));
|
vector_to_string(max.get_shape()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -64,9 +64,9 @@ namespace ngraph
|
|||||||
const ngraph::element::Type& output_type,
|
const ngraph::element::Type& output_type,
|
||||||
const bool requantize = true);
|
const bool requantize = true);
|
||||||
|
|
||||||
void check_concat(const NodeVector& args,
|
void check_concat(const OutputVector& args,
|
||||||
const NodeVector& mins,
|
const OutputVector& mins,
|
||||||
const NodeVector& maxs);
|
const OutputVector& maxs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,28 +25,28 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace builder
|
namespace builder
|
||||||
{
|
{
|
||||||
shared_ptr<Node> QuantizedConcatBuilder(const NodeVector& args,
|
shared_ptr<Node> QuantizedConcatBuilder(const OutputVector& args,
|
||||||
size_t concatenation_axis,
|
size_t concatenation_axis,
|
||||||
const NodeVector& mins,
|
const OutputVector& mins,
|
||||||
const NodeVector& maxs)
|
const OutputVector& maxs)
|
||||||
{
|
{
|
||||||
quantization_utils::check_concat(args, mins, maxs);
|
quantization_utils::check_concat(args, mins, maxs);
|
||||||
auto quant_type = args[0]->get_element_type();
|
auto quant_type = args[0].get_element_type();
|
||||||
|
|
||||||
// output scale
|
// output scale
|
||||||
auto min = make_shared<op::Min>(make_shared<op::Concat>(mins, 0), ngraph::AxisSet{0});
|
auto min = make_shared<op::Min>(make_shared<op::Concat>(mins, 0), ngraph::AxisSet{0});
|
||||||
auto max = make_shared<op::Max>(make_shared<op::Concat>(maxs, 0), ngraph::AxisSet{0});
|
auto max = make_shared<op::Max>(make_shared<op::Concat>(maxs, 0), ngraph::AxisSet{0});
|
||||||
auto out_scale = quantization_utils::get_scale(min, max, quant_type);
|
auto out_scale = quantization_utils::get_scale(min, max, quant_type);
|
||||||
|
|
||||||
NodeVector rescaled_args(args.size());
|
OutputVector rescaled_args(args.size());
|
||||||
for (size_t i = 0; i < args.size(); ++i)
|
for (size_t i = 0; i < args.size(); ++i)
|
||||||
{
|
{
|
||||||
auto q_type = args[i]->get_element_type();
|
auto q_type = args[i].get_element_type();
|
||||||
auto in_scale = make_shared<ngraph::op::Reshape>(
|
auto in_scale = make_shared<ngraph::op::Reshape>(
|
||||||
quantization_utils::get_scale(mins[i], maxs[i], q_type),
|
quantization_utils::get_scale(mins[i], maxs[i], q_type),
|
||||||
AxisVector{0},
|
AxisVector{0},
|
||||||
Shape{});
|
Shape{});
|
||||||
auto zero = make_constant(q_type, in_scale->get_shape(), 0);
|
auto zero = make_constant(q_type, in_scale->get_output_shape(0), 0);
|
||||||
|
|
||||||
rescaled_args[i] =
|
rescaled_args[i] =
|
||||||
make_shared<op::Dequantize>(args[i], in_scale, zero, element::f32, AxisSet{});
|
make_shared<op::Dequantize>(args[i], in_scale, zero, element::f32, AxisSet{});
|
||||||
@ -58,7 +58,7 @@ namespace ngraph
|
|||||||
AxisSet{},
|
AxisSet{},
|
||||||
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
|
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
|
||||||
}
|
}
|
||||||
OutputVector base = as_output_vector(args);
|
OutputVector base = args;
|
||||||
for (auto node : mins)
|
for (auto node : mins)
|
||||||
{
|
{
|
||||||
base.push_back(node);
|
base.push_back(node);
|
||||||
|
@ -32,9 +32,9 @@ namespace ngraph
|
|||||||
namespace builder
|
namespace builder
|
||||||
{
|
{
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::shared_ptr<Node> QuantizedConcatBuilder(const NodeVector& args,
|
std::shared_ptr<Node> QuantizedConcatBuilder(const OutputVector& args,
|
||||||
size_t concatenation_axis,
|
size_t concatenation_axis,
|
||||||
const NodeVector& mins,
|
const OutputVector& mins,
|
||||||
const NodeVector& maxs);
|
const OutputVector& maxs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,37 +47,13 @@ namespace
|
|||||||
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
|
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
|
||||||
->add_provenance_group_members_above({output}));
|
->add_provenance_group_members_above({output}));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Return the outputs of the node as vector.
|
|
||||||
///
|
|
||||||
/// \param[in] node Node with multiple outputs.
|
|
||||||
///
|
|
||||||
/// \return Vector of outputs of input node.
|
|
||||||
NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
|
|
||||||
{
|
|
||||||
const auto outputs_number = node->get_output_size();
|
|
||||||
ngraph::NodeVector outputs(outputs_number);
|
|
||||||
for (int i = 0; i < outputs_number; ++i)
|
|
||||||
{
|
|
||||||
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
|
|
||||||
{
|
|
||||||
outputs[i] = node->get_output_as_single_output_node(i);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
outputs[i] = std::make_shared<op::GetOutputElement>(node, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return outputs;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector builder::split(const Output<ngraph::Node>& value,
|
OutputVector
|
||||||
const std::vector<size_t>& length_parts,
|
builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
|
||||||
size_t axis)
|
|
||||||
{
|
{
|
||||||
size_t start_index{0};
|
size_t start_index{0};
|
||||||
NodeVector outputs;
|
OutputVector outputs;
|
||||||
for (const auto& length_part : length_parts)
|
for (const auto& length_part : length_parts)
|
||||||
{
|
{
|
||||||
size_t end_index{start_index + length_part};
|
size_t end_index{start_index + length_part};
|
||||||
@ -87,7 +63,7 @@ NodeVector builder::split(const Output<ngraph::Node>& value,
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
|
OutputVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
|
||||||
{
|
{
|
||||||
size_t axis_to_split{static_cast<size_t>(axis)};
|
size_t axis_to_split{static_cast<size_t>(axis)};
|
||||||
if (axis < 0)
|
if (axis < 0)
|
||||||
@ -100,9 +76,9 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
|
|||||||
return split(value, length_parts, axis_to_split);
|
return split(value, length_parts, axis_to_split);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector builder::opset1::split(const Output<Node>& value,
|
OutputVector builder::opset1::split(const Output<Node>& value,
|
||||||
const std::vector<size_t>& split_lengths,
|
const std::vector<size_t>& split_lengths,
|
||||||
int64_t axis)
|
int64_t axis)
|
||||||
{
|
{
|
||||||
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
|
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
|
||||||
const auto split_lengths_node =
|
const auto split_lengths_node =
|
||||||
@ -110,13 +86,13 @@ NodeVector builder::opset1::split(const Output<Node>& value,
|
|||||||
const auto variadic_split =
|
const auto variadic_split =
|
||||||
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
|
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
|
||||||
|
|
||||||
return get_outputs(variadic_split);
|
return variadic_split->outputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
|
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
|
||||||
{
|
{
|
||||||
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
|
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
|
||||||
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
|
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
|
||||||
|
|
||||||
return get_outputs(split);
|
return split->outputs();
|
||||||
}
|
}
|
||||||
|
@ -31,9 +31,9 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing multiple nodes we split input node into.
|
/// \return The vector containing multiple nodes we split input node into.
|
||||||
///
|
///
|
||||||
NodeVector split(const Output<Node>& value,
|
OutputVector split(const Output<Node>& value,
|
||||||
const std::vector<size_t>& length_parts,
|
const std::vector<size_t>& length_parts,
|
||||||
size_t axis = 0);
|
size_t axis = 0);
|
||||||
|
|
||||||
/// \brief Split node on specified axis into multiple parts.
|
/// \brief Split node on specified axis into multiple parts.
|
||||||
///
|
///
|
||||||
@ -47,9 +47,9 @@ namespace ngraph
|
|||||||
/// indexing). This means that the axis to split on will be counted from
|
/// indexing). This means that the axis to split on will be counted from
|
||||||
/// the back of the tensor (negative values are subtracted from its rank).
|
/// the back of the tensor (negative values are subtracted from its rank).
|
||||||
///
|
///
|
||||||
/// \return The vector containing multiple nodes we split input node into.
|
/// \return The vector containing multiple outputs we split input node into.
|
||||||
///
|
///
|
||||||
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
|
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
|
||||||
|
|
||||||
namespace opset1
|
namespace opset1
|
||||||
{
|
{
|
||||||
@ -63,13 +63,13 @@ namespace ngraph
|
|||||||
/// indexing). This means that the axis to split on will be counted from
|
/// indexing). This means that the axis to split on will be counted from
|
||||||
/// the back of the tensor (negative values are subtracted from its rank).
|
/// the back of the tensor (negative values are subtracted from its rank).
|
||||||
///
|
///
|
||||||
/// \return The vector containing multiple nodes we split input node into.
|
/// \return The vector containing multiple outputs we split input node into.
|
||||||
/// The vector is output of Split:v1 op
|
/// The vector is output of Split:v1 op
|
||||||
///
|
///
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
NodeVector split(const Output<Node>& value,
|
OutputVector split(const Output<Node>& value,
|
||||||
const std::vector<size_t>& split_lengths,
|
const std::vector<size_t>& split_lengths,
|
||||||
int64_t axis = 0);
|
int64_t axis = 0);
|
||||||
|
|
||||||
/// \brief Split value on specified axis into multiple parts.
|
/// \brief Split value on specified axis into multiple parts.
|
||||||
///
|
///
|
||||||
@ -88,7 +88,7 @@ namespace ngraph
|
|||||||
/// The vector is output of VariadicSplit:v1 op
|
/// The vector is output of VariadicSplit:v1 op
|
||||||
///
|
///
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
|
OutputVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
|
||||||
}
|
}
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -157,7 +157,7 @@ namespace ngraph
|
|||||||
m_nodes.emplace_back(node_proto, *this);
|
m_nodes.emplace_back(node_proto, *this);
|
||||||
const Node& node{m_nodes.back()};
|
const Node& node{m_nodes.back()};
|
||||||
|
|
||||||
NodeVector ng_nodes{node.get_ng_nodes()};
|
OutputVector ng_nodes{node.get_ng_nodes()};
|
||||||
// Iterate over the number of outputs for given node in graph.
|
// Iterate over the number of outputs for given node in graph.
|
||||||
// Some of them may be optional and trimmed. See:
|
// Some of them may be optional and trimmed. See:
|
||||||
// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
|
// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
|
||||||
@ -174,14 +174,14 @@ namespace ngraph
|
|||||||
return m_cache->contains(name);
|
return m_cache->contains(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
|
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
|
||||||
{
|
{
|
||||||
return m_cache->get_node(name);
|
return m_cache->get_node(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector Graph::get_ng_outputs() const
|
OutputVector Graph::get_ng_outputs() const
|
||||||
{
|
{
|
||||||
NodeVector results;
|
OutputVector results;
|
||||||
for (const auto& output : m_graph_proto->output())
|
for (const auto& output : m_graph_proto->output())
|
||||||
{
|
{
|
||||||
results.emplace_back(get_ng_node_from_cache(output.name()));
|
results.emplace_back(get_ng_node_from_cache(output.name()));
|
||||||
@ -189,11 +189,11 @@ namespace ngraph
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector Graph::make_ng_nodes(const Node& onnx_node) const
|
OutputVector Graph::make_ng_nodes(const Node& onnx_node) const
|
||||||
{
|
{
|
||||||
const auto ng_node_factory =
|
const auto ng_node_factory =
|
||||||
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
|
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
|
||||||
NodeVector ng_node_vector;
|
OutputVector ng_node_vector;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
ng_node_vector = ng_node_factory(onnx_node);
|
ng_node_vector = ng_node_factory(onnx_node);
|
||||||
@ -223,7 +223,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Graph::set_friendly_names(const Node& onnx_node,
|
void Graph::set_friendly_names(const Node& onnx_node,
|
||||||
const NodeVector& ng_node_vector) const
|
const OutputVector& ng_node_vector) const
|
||||||
{
|
{
|
||||||
for (int i = 0; i < ng_node_vector.size(); ++i)
|
for (int i = 0; i < ng_node_vector.size(); ++i)
|
||||||
{
|
{
|
||||||
@ -234,7 +234,7 @@ namespace ngraph
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
ng_node_vector[i]->set_friendly_name(onnx_node.output(i));
|
ng_node_vector[i].get_node()->set_friendly_name(onnx_node.output(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -267,7 +267,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Graph::add_provenance_tags(const Node& onnx_node,
|
void Graph::add_provenance_tags(const Node& onnx_node,
|
||||||
const NodeVector& ng_node_vector) const
|
const OutputVector& ng_node_vector) const
|
||||||
{
|
{
|
||||||
if (!ngraph::get_provenance_enabled())
|
if (!ngraph::get_provenance_enabled())
|
||||||
{
|
{
|
||||||
@ -278,9 +278,9 @@ namespace ngraph
|
|||||||
const auto ng_inputs = onnx_node.get_ng_inputs();
|
const auto ng_inputs = onnx_node.get_ng_inputs();
|
||||||
|
|
||||||
ngraph::traverse_nodes(
|
ngraph::traverse_nodes(
|
||||||
ng_node_vector,
|
as_node_vector(ng_node_vector),
|
||||||
[&tag](std::shared_ptr<ngraph::Node> ng_node) { ng_node->add_provenance_tag(tag); },
|
[&tag](std::shared_ptr<ngraph::Node> ng_node) { ng_node->add_provenance_tag(tag); },
|
||||||
ng_inputs);
|
as_node_vector(ng_inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
Subgraph::Subgraph(const ONNX_NAMESPACE::GraphProto& proto,
|
Subgraph::Subgraph(const ONNX_NAMESPACE::GraphProto& proto,
|
||||||
|
@ -39,12 +39,12 @@ namespace ngraph
|
|||||||
const std::vector<Node>& get_nodes() const { return m_nodes; }
|
const std::vector<Node>& get_nodes() const { return m_nodes; }
|
||||||
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
|
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
|
||||||
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
|
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
|
||||||
NodeVector get_ng_outputs() const;
|
OutputVector get_ng_outputs() const;
|
||||||
const ParameterVector& get_ng_parameters() const { return m_parameters; }
|
const ParameterVector& get_ng_parameters() const { return m_parameters; }
|
||||||
bool is_node_in_cache(const std::string& name) const;
|
bool is_node_in_cache(const std::string& name) const;
|
||||||
std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
|
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
|
||||||
const std::string& get_name() const { return m_graph_proto->name(); }
|
const std::string& get_name() const { return m_graph_proto->name(); }
|
||||||
NodeVector make_ng_nodes(const Node& onnx_node) const;
|
OutputVector make_ng_nodes(const Node& onnx_node) const;
|
||||||
const GraphCache& get_graph_cache() const;
|
const GraphCache& get_graph_cache() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -52,7 +52,8 @@ namespace ngraph
|
|||||||
Model& model,
|
Model& model,
|
||||||
std::unique_ptr<GraphCache>&& cache);
|
std::unique_ptr<GraphCache>&& cache);
|
||||||
|
|
||||||
void set_friendly_names(const Node& onnx_node, const NodeVector& ng_node_vector) const;
|
void set_friendly_names(const Node& onnx_node,
|
||||||
|
const OutputVector& ng_node_vector) const;
|
||||||
|
|
||||||
void add_provenance_tag_to_initializer(
|
void add_provenance_tag_to_initializer(
|
||||||
const Tensor& initializer, std::shared_ptr<default_opset::Constant> node) const;
|
const Tensor& initializer, std::shared_ptr<default_opset::Constant> node) const;
|
||||||
@ -60,7 +61,8 @@ namespace ngraph
|
|||||||
void add_provenance_tag_to_input(const ValueInfo& input,
|
void add_provenance_tag_to_input(const ValueInfo& input,
|
||||||
std::shared_ptr<ngraph::Node> node) const;
|
std::shared_ptr<ngraph::Node> node) const;
|
||||||
|
|
||||||
void add_provenance_tags(const Node& onnx_node, const NodeVector& ng_node_vector) const;
|
void add_provenance_tags(const Node& onnx_node,
|
||||||
|
const OutputVector& ng_node_vector) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
|
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
|
||||||
|
@ -21,12 +21,12 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace onnx_import
|
namespace onnx_import
|
||||||
{
|
{
|
||||||
void GraphCache::emplace_node(const std::string& name, std::shared_ptr<ngraph::Node>&& node)
|
void GraphCache::emplace_node(const std::string& name, Output<ngraph::Node>&& node)
|
||||||
{
|
{
|
||||||
m_graph_cache_map[name] = std::move(node);
|
m_graph_cache_map[name] = std::move(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> GraphCache::get_node(const std::string& name) const
|
Output<ngraph::Node> GraphCache::get_node(const std::string& name) const
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
@ -52,7 +52,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> SubgraphCache::get_node(const std::string& name) const
|
Output<ngraph::Node> SubgraphCache::get_node(const std::string& name) const
|
||||||
{
|
{
|
||||||
// present in subgraph scope
|
// present in subgraph scope
|
||||||
if (GraphCache::contains(name))
|
if (GraphCache::contains(name))
|
||||||
|
@ -35,7 +35,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param[in] name The name of node added to the cache.
|
/// \param[in] name The name of node added to the cache.
|
||||||
/// \param[in] node The node added to the cache.
|
/// \param[in] node The node added to the cache.
|
||||||
void emplace_node(const std::string& name, std::shared_ptr<ngraph::Node>&& node);
|
void emplace_node(const std::string& name, Output<ngraph::Node>&& node);
|
||||||
|
|
||||||
/// \brief Get the node from the cache
|
/// \brief Get the node from the cache
|
||||||
///
|
///
|
||||||
@ -44,7 +44,7 @@ namespace ngraph
|
|||||||
/// \param[in] name The name of the node.
|
/// \param[in] name The name of the node.
|
||||||
///
|
///
|
||||||
/// \return The node named `name`.
|
/// \return The node named `name`.
|
||||||
virtual std::shared_ptr<ngraph::Node> get_node(const std::string& name) const;
|
virtual Output<ngraph::Node> get_node(const std::string& name) const;
|
||||||
|
|
||||||
/// \brief Return true if the node named `name` exist in the cache.
|
/// \brief Return true if the node named `name` exist in the cache.
|
||||||
///
|
///
|
||||||
@ -54,7 +54,7 @@ namespace ngraph
|
|||||||
virtual bool contains(const std::string& name) const;
|
virtual bool contains(const std::string& name) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<std::string, std::shared_ptr<ngraph::Node>> m_graph_cache_map;
|
std::map<std::string, Output<ngraph::Node>> m_graph_cache_map;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SubgraphCache : public GraphCache
|
class SubgraphCache : public GraphCache
|
||||||
@ -72,7 +72,7 @@ namespace ngraph
|
|||||||
/// \param[in] name The name of the node.
|
/// \param[in] name The name of the node.
|
||||||
///
|
///
|
||||||
/// \return The node named `name` from subgraph (as present) or from parent graph.
|
/// \return The node named `name` from subgraph (as present) or from parent graph.
|
||||||
std::shared_ptr<ngraph::Node> get_node(const std::string& name) const override;
|
Output<ngraph::Node> get_node(const std::string& name) const override;
|
||||||
|
|
||||||
/// \brief Return true if the node named `name` exist in the cache.
|
/// \brief Return true if the node named `name` exist in the cache.
|
||||||
///
|
///
|
||||||
|
@ -40,8 +40,8 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<Attribute>& attributes() const;
|
const std::vector<Attribute>& attributes() const;
|
||||||
NodeVector get_ng_nodes(const Node& node) const;
|
OutputVector get_ng_nodes(const Node& node) const;
|
||||||
NodeVector get_ng_inputs() const;
|
OutputVector get_ng_inputs() const;
|
||||||
|
|
||||||
const std::string& domain() const;
|
const std::string& domain() const;
|
||||||
const std::string& op_type() const;
|
const std::string& op_type() const;
|
||||||
@ -140,14 +140,14 @@ namespace ngraph
|
|||||||
return it->get_subgraph(graph());
|
return it->get_subgraph(graph());
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector Node::Impl::get_ng_nodes(const Node& node) const
|
OutputVector Node::Impl::get_ng_nodes(const Node& node) const
|
||||||
{
|
{
|
||||||
return m_graph->make_ng_nodes(node);
|
return m_graph->make_ng_nodes(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector Node::Impl::get_ng_inputs() const
|
OutputVector Node::Impl::get_ng_inputs() const
|
||||||
{
|
{
|
||||||
NodeVector result;
|
OutputVector result;
|
||||||
for (const auto& name : m_node_proto->input())
|
for (const auto& name : m_node_proto->input())
|
||||||
{
|
{
|
||||||
if (!name.empty())
|
if (!name.empty())
|
||||||
@ -156,7 +156,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
result.push_back(std::make_shared<NullNode>());
|
result.push_back(std::make_shared<NullNode>()->output(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
@ -197,8 +197,8 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector Node::get_ng_inputs() const { return m_pimpl->get_ng_inputs(); }
|
OutputVector Node::get_ng_inputs() const { return m_pimpl->get_ng_inputs(); }
|
||||||
NodeVector Node::get_ng_nodes() const { return m_pimpl->get_ng_nodes(*this); }
|
OutputVector Node::get_ng_nodes() const { return m_pimpl->get_ng_nodes(*this); }
|
||||||
const std::string& Node::domain() const { return m_pimpl->domain(); }
|
const std::string& Node::domain() const { return m_pimpl->domain(); }
|
||||||
const std::string& Node::op_type() const { return m_pimpl->op_type(); }
|
const std::string& Node::op_type() const { return m_pimpl->op_type(); }
|
||||||
const std::string& Node::get_description() const { return m_pimpl->description(); }
|
const std::string& Node::get_description() const { return m_pimpl->description(); }
|
||||||
|
@ -64,8 +64,8 @@ namespace ngraph
|
|||||||
Node& operator=(Node&&) noexcept = delete;
|
Node& operator=(Node&&) noexcept = delete;
|
||||||
Node& operator=(const Node&) = delete;
|
Node& operator=(const Node&) = delete;
|
||||||
|
|
||||||
NodeVector get_ng_inputs() const;
|
OutputVector get_ng_inputs() const;
|
||||||
NodeVector get_ng_nodes() const;
|
OutputVector get_ng_nodes() const;
|
||||||
const std::string& domain() const;
|
const std::string& domain() const;
|
||||||
const std::string& op_type() const;
|
const std::string& op_type() const;
|
||||||
const std::string& get_name() const;
|
const std::string& get_name() const;
|
||||||
|
@ -42,3 +42,8 @@ bool ngraph::op::is_null(const std::shared_ptr<ngraph::Node>& node)
|
|||||||
{
|
{
|
||||||
return is_null(node.get());
|
return is_null(node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ngraph::op::is_null(const Output<ngraph::Node>& output)
|
||||||
|
{
|
||||||
|
return is_null(output.get_node());
|
||||||
|
}
|
||||||
|
@ -29,6 +29,8 @@ namespace ngraph
|
|||||||
bool is_null(const ngraph::Node* node);
|
bool is_null(const ngraph::Node* node);
|
||||||
ONNX_IMPORTER_API
|
ONNX_IMPORTER_API
|
||||||
bool is_null(const std::shared_ptr<ngraph::Node>& node);
|
bool is_null(const std::shared_ptr<ngraph::Node>& node);
|
||||||
|
ONNX_IMPORTER_API
|
||||||
|
bool is_null(const Output<ngraph::Node>& output);
|
||||||
}
|
}
|
||||||
namespace onnx_import
|
namespace onnx_import
|
||||||
{
|
{
|
||||||
|
@ -28,7 +28,7 @@ namespace ngraph
|
|||||||
namespace onnx_import
|
namespace onnx_import
|
||||||
{
|
{
|
||||||
/// \brief Function which transforms single ONNX operator to nGraph sub-graph.
|
/// \brief Function which transforms single ONNX operator to nGraph sub-graph.
|
||||||
using Operator = std::function<NodeVector(const Node&)>;
|
using Operator = std::function<OutputVector(const Node&)>;
|
||||||
|
|
||||||
/// \brief Map which contains ONNX operators accessible by std::string value as a key.
|
/// \brief Map which contains ONNX operators accessible by std::string value as a key.
|
||||||
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
|
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector abs(const Node& node)
|
inline OutputVector abs(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Abs>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Abs>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector acos(const Node& node)
|
inline OutputVector acos(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Acos>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Acos>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector acosh(const Node& node)
|
inline OutputVector acosh(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Acosh>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Acosh>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector add(const Node& node)
|
OutputVector add(const Node& node)
|
||||||
{
|
{
|
||||||
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
|
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
|
||||||
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
|
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
|
||||||
@ -45,7 +45,7 @@ namespace ngraph
|
|||||||
|
|
||||||
namespace set_7
|
namespace set_7
|
||||||
{
|
{
|
||||||
NodeVector add(const Node& node)
|
OutputVector add(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Add>(node.get_ng_inputs().at(0),
|
return {std::make_shared<default_opset::Add>(node.get_ng_inputs().at(0),
|
||||||
node.get_ng_inputs().at(1))};
|
node.get_ng_inputs().at(1))};
|
||||||
|
@ -29,13 +29,13 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector add(const Node& node);
|
OutputVector add(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
namespace set_7
|
namespace set_7
|
||||||
{
|
{
|
||||||
NodeVector add(const Node& node);
|
OutputVector add(const Node& node);
|
||||||
|
|
||||||
} // namespace set_7
|
} // namespace set_7
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector logical_and(const Node& node)
|
inline OutputVector logical_and(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::LogicalAnd>(
|
return {std::make_shared<default_opset::LogicalAnd>(
|
||||||
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
|
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
|
||||||
|
@ -25,7 +25,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector argmax(const Node& node)
|
OutputVector argmax(const Node& node)
|
||||||
{
|
{
|
||||||
const utils::ArgMinMaxFactory arg_factory(node);
|
const utils::ArgMinMaxFactory arg_factory(node);
|
||||||
return {arg_factory.make_arg_max()};
|
return {arg_factory.make_arg_max()};
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing an Ngraph node which produces the output
|
/// \return The vector containing an Ngraph node which produces the output
|
||||||
/// of an ONNX ArgMax operation.
|
/// of an ONNX ArgMax operation.
|
||||||
NodeVector argmax(const Node& node);
|
OutputVector argmax(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector argmin(const Node& node)
|
OutputVector argmin(const Node& node)
|
||||||
{
|
{
|
||||||
const utils::ArgMinMaxFactory arg_factory(node);
|
const utils::ArgMinMaxFactory arg_factory(node);
|
||||||
return {arg_factory.make_arg_min()};
|
return {arg_factory.make_arg_min()};
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing an Ngraph node which produces the output
|
/// \return The vector containing an Ngraph node which produces the output
|
||||||
/// of an ONNX ArgMin operation.
|
/// of an ONNX ArgMin operation.
|
||||||
NodeVector argmin(const Node& node);
|
OutputVector argmin(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector asin(const Node& node)
|
inline OutputVector asin(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Asin>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Asin>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector asinh(const Node& node)
|
inline OutputVector asinh(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Asinh>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Asinh>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector atan(const Node& node)
|
inline OutputVector atan(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Atan>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Atan>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector atanh(const Node& node)
|
inline OutputVector atanh(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Atanh>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Atanh>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector average_pool(const Node& node)
|
OutputVector average_pool(const Node& node)
|
||||||
{
|
{
|
||||||
return pooling::LocalPoolingFactory(node).make_avg_pool();
|
return pooling::LocalPoolingFactory(node).make_avg_pool();
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing Ngraph nodes producing output of ONNX AveragePool
|
/// \return The vector containing Ngraph nodes producing output of ONNX AveragePool
|
||||||
/// operation.
|
/// operation.
|
||||||
NodeVector average_pool(const Node& node);
|
OutputVector average_pool(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -30,14 +30,14 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector batch_norm(const Node& node)
|
OutputVector batch_norm(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector inputs{node.get_ng_inputs()};
|
OutputVector inputs{node.get_ng_inputs()};
|
||||||
auto x = inputs.at(0);
|
auto x = inputs.at(0);
|
||||||
auto scale = inputs.at(1);
|
auto scale = inputs.at(1);
|
||||||
auto bias = inputs.at(2);
|
auto bias = inputs.at(2);
|
||||||
std::shared_ptr<ngraph::Node> mean{nullptr};
|
Output<ngraph::Node> mean;
|
||||||
std::shared_ptr<ngraph::Node> var{nullptr};
|
Output<ngraph::Node> var;
|
||||||
|
|
||||||
std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)};
|
std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)};
|
||||||
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
|
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector batch_norm(const Node& node);
|
OutputVector batch_norm(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cast(const Node& node)
|
OutputVector cast(const Node& node)
|
||||||
{
|
{
|
||||||
auto data = node.get_ng_inputs().at(0);
|
auto data = node.get_ng_inputs().at(0);
|
||||||
int64_t target_type = node.get_attribute_value<int64_t>("to");
|
int64_t target_type = node.get_attribute_value<int64_t>("to");
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cast(const Node& node);
|
OutputVector cast(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector ceil(const Node& node)
|
inline OutputVector ceil(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Ceiling>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Ceiling>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector clip(const Node& node)
|
OutputVector clip(const Node& node)
|
||||||
{
|
{
|
||||||
const auto data = node.get_ng_inputs().at(0);
|
const auto data = node.get_ng_inputs().at(0);
|
||||||
|
|
||||||
@ -47,14 +47,14 @@ namespace ngraph
|
|||||||
|
|
||||||
namespace set_11
|
namespace set_11
|
||||||
{
|
{
|
||||||
NodeVector clip(const Node& node)
|
OutputVector clip(const Node& node)
|
||||||
{
|
{
|
||||||
const NodeVector inputs{node.get_ng_inputs()};
|
const OutputVector inputs{node.get_ng_inputs()};
|
||||||
const std::shared_ptr<ngraph::Node> data = inputs.at(0);
|
const Output<ngraph::Node> data = inputs.at(0);
|
||||||
const element::Type data_type = data->get_element_type();
|
const element::Type data_type = data.get_element_type();
|
||||||
const Shape data_shape = data->get_shape();
|
const Shape data_shape = data.get_shape();
|
||||||
std::shared_ptr<ngraph::Node> min;
|
Output<ngraph::Node> min;
|
||||||
std::shared_ptr<ngraph::Node> max;
|
Output<ngraph::Node> max;
|
||||||
|
|
||||||
// If second input is provided, assign to min input, otherwise set lowest
|
// If second input is provided, assign to min input, otherwise set lowest
|
||||||
// numeric limit of double as min input.
|
// numeric limit of double as min input.
|
||||||
|
@ -27,13 +27,13 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector clip(const Node& node);
|
OutputVector clip(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
namespace set_11
|
namespace set_11
|
||||||
{
|
{
|
||||||
NodeVector clip(const Node& node);
|
OutputVector clip(const Node& node);
|
||||||
|
|
||||||
} // namespace set_11
|
} // namespace set_11
|
||||||
|
|
||||||
|
@ -30,9 +30,9 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector concat(const Node& node)
|
OutputVector concat(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector inputs{node.get_ng_inputs()};
|
OutputVector inputs{node.get_ng_inputs()};
|
||||||
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
|
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
|
||||||
return {std::make_shared<default_opset::Concat>(inputs, axis)};
|
return {std::make_shared<default_opset::Concat>(inputs, axis)};
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector concat(const Node& node);
|
OutputVector concat(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector constant(const onnx_import::Node& node)
|
OutputVector constant(const onnx_import::Node& node)
|
||||||
{
|
{
|
||||||
return {make_constant(node.get_attribute_value<Tensor>("value"))};
|
return {make_constant(node.get_attribute_value<Tensor>("value"))};
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector constant(const Node& node);
|
OutputVector constant(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -28,9 +28,9 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector constant_of_shape(const onnx_import::Node& node)
|
OutputVector constant_of_shape(const onnx_import::Node& node)
|
||||||
{
|
{
|
||||||
std::shared_ptr<ngraph::Node> constant_value;
|
Output<ngraph::Node> constant_value;
|
||||||
if (node.has_attribute("value"))
|
if (node.has_attribute("value"))
|
||||||
{
|
{
|
||||||
auto value_tensor = node.get_attribute_value<Tensor>("value");
|
auto value_tensor = node.get_attribute_value<Tensor>("value");
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector constant_of_shape(const Node& node);
|
OutputVector constant_of_shape(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ namespace ngraph
|
|||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
std::shared_ptr<ngraph::op::Op>
|
std::shared_ptr<ngraph::op::Op>
|
||||||
make_ng_convolution(const std::shared_ptr<ngraph::Node>& data,
|
make_ng_convolution(const Output<ngraph::Node>& data,
|
||||||
const std::shared_ptr<ngraph::Node>& filters,
|
const Output<ngraph::Node>& filters,
|
||||||
const ngraph::Strides& strides,
|
const ngraph::Strides& strides,
|
||||||
const ngraph::Strides& dilations,
|
const ngraph::Strides& dilations,
|
||||||
const ngraph::CoordinateDiff& padding_below,
|
const ngraph::CoordinateDiff& padding_below,
|
||||||
@ -49,7 +49,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
if (groups > 1)
|
if (groups > 1)
|
||||||
{
|
{
|
||||||
auto filters_shape = filters->get_shape();
|
auto filters_shape = filters.get_shape();
|
||||||
filters_shape.at(0) = filters_shape.at(0) / groups;
|
filters_shape.at(0) = filters_shape.at(0) / groups;
|
||||||
filters_shape.insert(filters_shape.begin(), groups);
|
filters_shape.insert(filters_shape.begin(), groups);
|
||||||
|
|
||||||
@ -77,18 +77,16 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node>
|
std::shared_ptr<ngraph::Node> add_bias(const Output<ngraph::Node>& ng_conv,
|
||||||
add_bias(const std::shared_ptr<ngraph::Node>& ng_conv,
|
const Output<ngraph::Node>& bias)
|
||||||
const std::shared_ptr<ngraph::Node>& bias)
|
|
||||||
{
|
{
|
||||||
const auto rank_of_conv =
|
const auto rank_of_conv = ng_conv.get_partial_shape().rank().get_length();
|
||||||
ng_conv->get_output_partial_shape(0).rank().get_length();
|
|
||||||
|
|
||||||
// reshape the bias node {M} to {1, M, 1, 1, ..., 1}
|
// reshape the bias node {M} to {1, M, 1, 1, ..., 1}
|
||||||
// this is required by the addition operation that needs to be able
|
// this is required by the addition operation that needs to be able
|
||||||
// to broadcast the bias to match the shape of the convolution node
|
// to broadcast the bias to match the shape of the convolution node
|
||||||
std::vector<size_t> reshape_pattern_values(rank_of_conv, 1U);
|
std::vector<size_t> reshape_pattern_values(rank_of_conv, 1U);
|
||||||
reshape_pattern_values[1] = bias->get_shape().front();
|
reshape_pattern_values[1] = bias.get_shape().front();
|
||||||
const auto reshape_pattern =
|
const auto reshape_pattern =
|
||||||
default_opset::Constant::create(element::u64,
|
default_opset::Constant::create(element::u64,
|
||||||
Shape{reshape_pattern_values.size()},
|
Shape{reshape_pattern_values.size()},
|
||||||
@ -101,16 +99,16 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
NodeVector conv(const Node& node)
|
OutputVector conv(const Node& node)
|
||||||
{
|
{
|
||||||
// in the current implementation we assume that the data input rank is static
|
// in the current implementation we assume that the data input rank is static
|
||||||
// and only the 'batch' dimension can be dynamic
|
// and only the 'batch' dimension can be dynamic
|
||||||
const NodeVector& inputs = node.get_ng_inputs();
|
const OutputVector& inputs = node.get_ng_inputs();
|
||||||
const auto data = inputs.at(0);
|
const auto data = inputs.at(0);
|
||||||
const auto filters = inputs.at(1);
|
const auto filters = inputs.at(1);
|
||||||
const auto groups = node.get_attribute_value<int64_t>("group", 1);
|
const auto groups = node.get_attribute_value<int64_t>("group", 1);
|
||||||
|
|
||||||
NGRAPH_CHECK(data->get_output_partial_shape(0).rank().is_static(),
|
NGRAPH_CHECK(data.get_partial_shape().rank().is_static(),
|
||||||
"The input data tensor's rank has to be known (static)");
|
"The input data tensor's rank has to be known (static)");
|
||||||
|
|
||||||
const auto strides = convpool::get_strides(node);
|
const auto strides = convpool::get_strides(node);
|
||||||
@ -137,7 +135,7 @@ namespace ngraph
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
const auto bias = inputs.at(2);
|
const auto bias = inputs.at(2);
|
||||||
const auto bias_ps = bias->get_output_partial_shape(0);
|
const auto bias_ps = bias.get_partial_shape();
|
||||||
|
|
||||||
NGRAPH_CHECK(bias_ps.is_static() && is_vector(bias_ps.to_shape()),
|
NGRAPH_CHECK(bias_ps.is_static() && is_vector(bias_ps.to_shape()),
|
||||||
"The bias input needs to be a static 1D vector");
|
"The bias input needs to be a static 1D vector");
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing Ngraph nodes producing output of ONNX convolution
|
/// \return The vector containing Ngraph nodes producing output of ONNX convolution
|
||||||
/// operation.
|
/// operation.
|
||||||
NodeVector conv(const Node& node);
|
OutputVector conv(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
//*****************************************************************************
|
//*****************************************************************************
|
||||||
|
|
||||||
|
// Disabled in CMakeList
|
||||||
|
// Update to higher opset required
|
||||||
|
|
||||||
#include "conv_integer.hpp"
|
#include "conv_integer.hpp"
|
||||||
#include "exceptions.hpp"
|
#include "exceptions.hpp"
|
||||||
#include "ngraph/builder/make_constant.hpp"
|
#include "ngraph/builder/make_constant.hpp"
|
||||||
@ -31,9 +34,9 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector conv_integer(const Node& node)
|
OutputVector conv_integer(const Node& node)
|
||||||
{
|
{
|
||||||
const NodeVector& inputs = node.get_ng_inputs();
|
const OutputVector& inputs = node.get_ng_inputs();
|
||||||
auto num_inputs = inputs.size();
|
auto num_inputs = inputs.size();
|
||||||
auto input = inputs.at(0);
|
auto input = inputs.at(0);
|
||||||
auto filters = inputs.at(1);
|
auto filters = inputs.at(1);
|
||||||
@ -51,19 +54,18 @@ namespace ngraph
|
|||||||
ngraph::op::PadType auto_pad_type = convpool::get_auto_pad(node);
|
ngraph::op::PadType auto_pad_type = convpool::get_auto_pad(node);
|
||||||
auto& padding_below = paddings.first;
|
auto& padding_below = paddings.first;
|
||||||
auto& padding_above = paddings.second;
|
auto& padding_above = paddings.second;
|
||||||
convpool::calculate_auto_pads(input->get_shape(),
|
convpool::calculate_auto_pads(input.get_shape(),
|
||||||
filters->get_shape(),
|
filters.get_shape(),
|
||||||
window_movement_strides,
|
window_movement_strides,
|
||||||
window_dilation_strides,
|
window_dilation_strides,
|
||||||
auto_pad_type,
|
auto_pad_type,
|
||||||
padding_below,
|
padding_below,
|
||||||
padding_above);
|
padding_above);
|
||||||
|
|
||||||
const Strides default_data_dilation_strides(input->get_shape().size() - 2, 1);
|
const Strides default_data_dilation_strides(input.get_shape().size() - 2, 1);
|
||||||
auto scale_one = make_constant(ngraph::element::f32, Shape{}, 1);
|
auto scale_one = make_constant(ngraph::element::f32, Shape{}, 1);
|
||||||
auto input_zero_point = make_constant(input->get_element_type(), Shape{}, 0);
|
auto input_zero_point = make_constant(input.get_element_type(), Shape{}, 0);
|
||||||
auto filters_zero_point =
|
auto filters_zero_point = make_constant(filters.get_element_type(), Shape{}, 0);
|
||||||
make_constant(filters->get_element_type(), Shape{}, 0);
|
|
||||||
auto output_zero_point = make_constant(ngraph::element::i32, Shape{}, 0);
|
auto output_zero_point = make_constant(ngraph::element::i32, Shape{}, 0);
|
||||||
|
|
||||||
if (num_inputs == 2)
|
if (num_inputs == 2)
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
//*****************************************************************************
|
//*****************************************************************************
|
||||||
|
|
||||||
|
// Disabled in CMakeList
|
||||||
|
// Update to higher opset required
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "core/node.hpp"
|
#include "core/node.hpp"
|
||||||
@ -33,7 +36,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing Ngraph nodes producing output of quantized ONNX
|
/// \return The vector containing Ngraph nodes producing output of quantized ONNX
|
||||||
/// convolution operation.
|
/// convolution operation.
|
||||||
NodeVector conv_integer(const Node& node);
|
OutputVector conv_integer(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -44,9 +44,9 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
std::shared_ptr<ngraph::Node>
|
Output<ngraph::Node>
|
||||||
make_group_conv_backprop(const std::shared_ptr<ngraph::Node>& data,
|
make_group_conv_backprop(const Output<ngraph::Node>& data,
|
||||||
const std::shared_ptr<ngraph::Node>& filters,
|
const Output<ngraph::Node>& filters,
|
||||||
const Strides& strides,
|
const Strides& strides,
|
||||||
const Strides& dilations,
|
const Strides& dilations,
|
||||||
const CoordinateDiff& pads_begin,
|
const CoordinateDiff& pads_begin,
|
||||||
@ -83,9 +83,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node>
|
Output<ngraph::Node>
|
||||||
make_conv_backprop(const std::shared_ptr<ngraph::Node>& data,
|
make_conv_backprop(const Output<ngraph::Node>& data,
|
||||||
const std::shared_ptr<ngraph::Node>& filters,
|
const Output<ngraph::Node>& filters,
|
||||||
const Strides& strides,
|
const Strides& strides,
|
||||||
const Strides& dilations,
|
const Strides& dilations,
|
||||||
const CoordinateDiff& pads_begin,
|
const CoordinateDiff& pads_begin,
|
||||||
@ -124,10 +124,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node>
|
Output<ngraph::Node> get_reshaped_filters(const Output<ngraph::Node>& filters,
|
||||||
get_reshaped_filters(const std::shared_ptr<ngraph::Node>& filters,
|
const PartialShape& filters_pshape,
|
||||||
const PartialShape& filters_pshape,
|
int64_t groups)
|
||||||
int64_t groups)
|
|
||||||
{
|
{
|
||||||
if (filters_pshape.is_static())
|
if (filters_pshape.is_static())
|
||||||
{
|
{
|
||||||
@ -180,12 +179,11 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node>
|
Output<ngraph::Node> get_prepared_bias(const Output<ngraph::Node>& bias,
|
||||||
get_prepared_bias(const std::shared_ptr<ngraph::Node>& bias,
|
const Output<ngraph::Node>& conv)
|
||||||
const std::shared_ptr<ngraph::Node>& conv)
|
|
||||||
{
|
{
|
||||||
// Prepare bias shape [1, C, 1, 1]
|
// Prepare bias shape [1, C, 1, 1]
|
||||||
const auto& conv_pshape = conv->get_output_partial_shape(0);
|
const auto& conv_pshape = conv.get_partial_shape();
|
||||||
std::shared_ptr<ngraph::Node> bias_shape_node;
|
std::shared_ptr<ngraph::Node> bias_shape_node;
|
||||||
|
|
||||||
if (conv_pshape.rank().is_static() && conv_pshape[1].is_static())
|
if (conv_pshape.rank().is_static() && conv_pshape[1].is_static())
|
||||||
@ -231,9 +229,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector conv_transpose(const Node& node)
|
OutputVector conv_transpose(const Node& node)
|
||||||
{
|
{
|
||||||
const NodeVector& inputs = node.get_ng_inputs();
|
const OutputVector& inputs = node.get_ng_inputs();
|
||||||
|
|
||||||
CHECK_VALID_NODE(node,
|
CHECK_VALID_NODE(node,
|
||||||
inputs.size() == 2 || inputs.size() == 3,
|
inputs.size() == 2 || inputs.size() == 3,
|
||||||
@ -243,8 +241,8 @@ namespace ngraph
|
|||||||
auto data = inputs[0];
|
auto data = inputs[0];
|
||||||
auto filters = inputs[1];
|
auto filters = inputs[1];
|
||||||
|
|
||||||
const auto& data_pshape = data->get_output_partial_shape(0);
|
const auto& data_pshape = data.get_partial_shape();
|
||||||
const auto& filters_pshape = filters->get_output_partial_shape(0);
|
const auto& filters_pshape = filters.get_partial_shape();
|
||||||
|
|
||||||
std::size_t num_spatial_dims = 0;
|
std::size_t num_spatial_dims = 0;
|
||||||
Strides strides, dilations;
|
Strides strides, dilations;
|
||||||
@ -291,7 +289,7 @@ namespace ngraph
|
|||||||
CHECK_VALID_NODE(
|
CHECK_VALID_NODE(
|
||||||
node, groups >= 0, "Incorrect value of 'group' attribute: ", groups);
|
node, groups >= 0, "Incorrect value of 'group' attribute: ", groups);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> conv_node;
|
Output<ngraph::Node> conv_node;
|
||||||
|
|
||||||
// reshape filters to match desired shape:
|
// reshape filters to match desired shape:
|
||||||
// [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1]
|
// [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1]
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing Ngraph nodes producing output of ONNX convolution
|
/// \return The vector containing Ngraph nodes producing output of ONNX convolution
|
||||||
/// operation.
|
/// operation.
|
||||||
NodeVector conv_transpose(const Node& node);
|
OutputVector conv_transpose(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cos(const Node& node)
|
OutputVector cos(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Cos>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Cos>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cos(const Node& node);
|
OutputVector cos(const Node& node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cosh(const Node& node)
|
OutputVector cosh(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Cosh>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Cosh>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cosh(const Node& node);
|
OutputVector cosh(const Node& node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace onnx_import
|
} // namespace onnx_import
|
||||||
|
@ -27,13 +27,13 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cum_sum(const Node& node)
|
OutputVector cum_sum(const Node& node)
|
||||||
{
|
{
|
||||||
auto inputs = node.get_ng_inputs();
|
auto inputs = node.get_ng_inputs();
|
||||||
auto data = inputs.at(0);
|
auto data = inputs.at(0);
|
||||||
bool exclusive = node.get_attribute_value<std::int64_t>("exclusive", 0);
|
bool exclusive = node.get_attribute_value<std::int64_t>("exclusive", 0);
|
||||||
bool reverse = node.get_attribute_value<std::int64_t>("reverse", 0);
|
bool reverse = node.get_attribute_value<std::int64_t>("reverse", 0);
|
||||||
std::shared_ptr<ngraph::Node> axis;
|
Output<ngraph::Node> axis;
|
||||||
|
|
||||||
if (inputs.size() > 1)
|
if (inputs.size() > 1)
|
||||||
{
|
{
|
||||||
@ -44,7 +44,7 @@ namespace ngraph
|
|||||||
axis =
|
axis =
|
||||||
default_opset::Constant::create(element::i64, Shape{}, {0}); // default
|
default_opset::Constant::create(element::i64, Shape{}, {0}); // default
|
||||||
}
|
}
|
||||||
return NodeVector{
|
return OutputVector{
|
||||||
std::make_shared<default_opset::CumSum>(data, axis, exclusive, reverse)};
|
std::make_shared<default_opset::CumSum>(data, axis, exclusive, reverse)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector cum_sum(const Node& node);
|
OutputVector cum_sum(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector depth_to_space(const Node& node)
|
OutputVector depth_to_space(const Node& node)
|
||||||
{
|
{
|
||||||
auto data = node.get_ng_inputs().at(0);
|
auto data = node.get_ng_inputs().at(0);
|
||||||
const auto mode = node.get_attribute_value<std::string>("mode", "DCR");
|
const auto mode = node.get_attribute_value<std::string>("mode", "DCR");
|
||||||
@ -34,7 +34,7 @@ namespace ngraph
|
|||||||
? default_opset::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST
|
? default_opset::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST
|
||||||
: default_opset::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST;
|
: default_opset::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST;
|
||||||
const auto block_size = node.get_attribute_value<std::int64_t>("blocksize");
|
const auto block_size = node.get_attribute_value<std::int64_t>("blocksize");
|
||||||
return NodeVector{std::make_shared<default_opset::DepthToSpace>(
|
return OutputVector{std::make_shared<default_opset::DepthToSpace>(
|
||||||
data, ngraph_mode, block_size)};
|
data, ngraph_mode, block_size)};
|
||||||
}
|
}
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
@ -34,9 +34,9 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \param[in] node The ONNX input node describing operation.
|
/// \param[in] node The ONNX input node describing operation.
|
||||||
///
|
///
|
||||||
/// \return NodeVector containing Tensor with shape:
|
/// \return OutputVector containing Tensor with shape:
|
||||||
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
|
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
|
||||||
NodeVector depth_to_space(const Node& node);
|
OutputVector depth_to_space(const Node& node);
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -36,13 +36,13 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
|
Output<ngraph::Node> get_zero_point(const OutputVector& inputs)
|
||||||
{
|
{
|
||||||
if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2]))
|
if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2]))
|
||||||
{
|
{
|
||||||
auto zero_point = inputs[2];
|
auto zero_point = inputs[2];
|
||||||
|
|
||||||
if (zero_point->get_element_type() != element::f32)
|
if (zero_point.get_element_type() != element::f32)
|
||||||
{
|
{
|
||||||
zero_point =
|
zero_point =
|
||||||
std::make_shared<default_opset::Convert>(zero_point, element::f32);
|
std::make_shared<default_opset::Convert>(zero_point, element::f32);
|
||||||
@ -58,9 +58,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector dequantize_linear(const Node& node)
|
OutputVector dequantize_linear(const Node& node)
|
||||||
{
|
{
|
||||||
const NodeVector inputs{node.get_ng_inputs()};
|
const OutputVector inputs{node.get_ng_inputs()};
|
||||||
|
|
||||||
NGRAPH_CHECK(
|
NGRAPH_CHECK(
|
||||||
2 <= inputs.size() && inputs.size() <= 3,
|
2 <= inputs.size() && inputs.size() <= 3,
|
||||||
@ -71,8 +71,9 @@ namespace ngraph
|
|||||||
const auto scale = inputs[1];
|
const auto scale = inputs[1];
|
||||||
const auto zero_point = get_zero_point(inputs);
|
const auto zero_point = get_zero_point(inputs);
|
||||||
|
|
||||||
common::validate_scalar_input("Dequantization scale", scale, {element::f32});
|
common::validate_scalar_input(
|
||||||
common::validate_scalar_input("Zero point", zero_point);
|
"Dequantization scale", scale.get_node_shared_ptr(), {element::f32});
|
||||||
|
common::validate_scalar_input("Zero point", zero_point.get_node_shared_ptr());
|
||||||
|
|
||||||
const auto converted_x =
|
const auto converted_x =
|
||||||
std::make_shared<default_opset::Convert>(x, element::f32);
|
std::make_shared<default_opset::Convert>(x, element::f32);
|
||||||
@ -86,11 +87,11 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
void validate_scale(const std::shared_ptr<ngraph::Node> scale,
|
void validate_scale(const Output<ngraph::Node> scale,
|
||||||
const std::shared_ptr<ngraph::Node> x,
|
const Output<ngraph::Node> x,
|
||||||
const int64_t axis)
|
const int64_t axis)
|
||||||
{
|
{
|
||||||
const auto& scale_shape = scale->get_output_partial_shape(0);
|
const auto& scale_shape = scale.get_partial_shape();
|
||||||
NGRAPH_CHECK(scale_shape.rank().get_length() == 0 ||
|
NGRAPH_CHECK(scale_shape.rank().get_length() == 0 ||
|
||||||
scale_shape.rank().get_length() == 1,
|
scale_shape.rank().get_length() == 1,
|
||||||
"Dequantization scale needs to be a scalar or a vector.");
|
"Dequantization scale needs to be a scalar or a vector.");
|
||||||
@ -98,7 +99,7 @@ namespace ngraph
|
|||||||
if (scale_shape.rank().get_length() == 1)
|
if (scale_shape.rank().get_length() == 1)
|
||||||
{
|
{
|
||||||
const auto& scale_dim = scale_shape[0];
|
const auto& scale_dim = scale_shape[0];
|
||||||
const auto& x_shape = x->get_output_partial_shape(0);
|
const auto& x_shape = x.get_partial_shape();
|
||||||
const auto& x_dim_at_axis = x_shape[axis];
|
const auto& x_dim_at_axis = x_shape[axis];
|
||||||
|
|
||||||
NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis),
|
NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis),
|
||||||
@ -111,11 +112,11 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void validate_zero_point(const std::shared_ptr<ngraph::Node> zero_point,
|
void validate_zero_point(const Output<ngraph::Node> zero_point,
|
||||||
const std::shared_ptr<ngraph::Node> x,
|
const Output<ngraph::Node> x,
|
||||||
const int64_t axis)
|
const int64_t axis)
|
||||||
{
|
{
|
||||||
const auto& zero_point_shape = zero_point->get_output_partial_shape(0);
|
const auto& zero_point_shape = zero_point.get_partial_shape();
|
||||||
NGRAPH_CHECK(zero_point_shape.rank().get_length() == 0 ||
|
NGRAPH_CHECK(zero_point_shape.rank().get_length() == 0 ||
|
||||||
zero_point_shape.rank().get_length() == 1,
|
zero_point_shape.rank().get_length() == 1,
|
||||||
"Zero point needs to be a scalar or a vector.");
|
"Zero point needs to be a scalar or a vector.");
|
||||||
@ -123,7 +124,7 @@ namespace ngraph
|
|||||||
if (zero_point_shape.rank().get_length() == 1)
|
if (zero_point_shape.rank().get_length() == 1)
|
||||||
{
|
{
|
||||||
const auto& zero_point_dim = zero_point_shape[0];
|
const auto& zero_point_dim = zero_point_shape[0];
|
||||||
const auto& x_shape = x->get_output_partial_shape(0);
|
const auto& x_shape = x.get_partial_shape();
|
||||||
const auto& x_dim_at_axis = x_shape[axis];
|
const auto& x_dim_at_axis = x_shape[axis];
|
||||||
|
|
||||||
NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis),
|
NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis),
|
||||||
@ -136,10 +137,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node>
|
std::shared_ptr<ngraph::Node> reshape_input(const Output<ngraph::Node> input,
|
||||||
reshape_input(const std::shared_ptr<ngraph::Node> input,
|
const int64_t axis,
|
||||||
const int64_t axis,
|
const PartialShape& x_shape)
|
||||||
const PartialShape& x_shape)
|
|
||||||
{
|
{
|
||||||
std::vector<int64_t> target_dims;
|
std::vector<int64_t> target_dims;
|
||||||
|
|
||||||
@ -170,9 +170,9 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector dequantize_linear(const Node& node)
|
OutputVector dequantize_linear(const Node& node)
|
||||||
{
|
{
|
||||||
const NodeVector inputs{node.get_ng_inputs()};
|
const OutputVector inputs{node.get_ng_inputs()};
|
||||||
|
|
||||||
NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3,
|
NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3,
|
||||||
"The DequantizeLinear op expects 2 required and one optional "
|
"The DequantizeLinear op expects 2 required and one optional "
|
||||||
@ -183,7 +183,7 @@ namespace ngraph
|
|||||||
auto scale = inputs[1];
|
auto scale = inputs[1];
|
||||||
auto zero_point = get_zero_point(inputs);
|
auto zero_point = get_zero_point(inputs);
|
||||||
|
|
||||||
const auto x_shape = x->get_output_partial_shape(0);
|
const auto x_shape = x.get_partial_shape();
|
||||||
|
|
||||||
NGRAPH_CHECK(x_shape.rank().is_static(),
|
NGRAPH_CHECK(x_shape.rank().is_static(),
|
||||||
"Rank of the input data tensor has to be known (static).");
|
"Rank of the input data tensor has to be known (static).");
|
||||||
|
@ -27,13 +27,13 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector dequantize_linear(const Node& node);
|
OutputVector dequantize_linear(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
namespace set_13
|
namespace set_13
|
||||||
{
|
{
|
||||||
NodeVector dequantize_linear(const Node& node);
|
OutputVector dequantize_linear(const Node& node);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -32,7 +32,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector div(const Node& node)
|
inline OutputVector div(const Node& node)
|
||||||
{
|
{
|
||||||
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
|
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
|
||||||
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
|
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
|
||||||
@ -50,7 +50,7 @@ namespace ngraph
|
|||||||
|
|
||||||
namespace set_7
|
namespace set_7
|
||||||
{
|
{
|
||||||
inline NodeVector div(const Node& node)
|
inline OutputVector div(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Divide>(node.get_ng_inputs().at(0),
|
return {std::make_shared<default_opset::Divide>(node.get_ng_inputs().at(0),
|
||||||
node.get_ng_inputs().at(1))};
|
node.get_ng_inputs().at(1))};
|
||||||
|
@ -30,11 +30,12 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector dropout(const Node& node)
|
inline OutputVector dropout(const Node& node)
|
||||||
{
|
{
|
||||||
// First value is actual output of Dropout,
|
// First value is actual output of Dropout,
|
||||||
// the second one is just a placeholder for optional trailing output.
|
// the second one is just a placeholder for optional trailing output.
|
||||||
return {node.get_ng_inputs().at(0), std::make_shared<NullNode>()};
|
return {node.get_ng_inputs().at(0).get_node_shared_ptr(),
|
||||||
|
std::make_shared<NullNode>()};
|
||||||
}
|
}
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -28,12 +28,12 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector elu(const Node& node)
|
OutputVector elu(const Node& node)
|
||||||
{
|
{
|
||||||
auto data = node.get_ng_inputs().at(0);
|
auto data = node.get_ng_inputs().at(0);
|
||||||
double alpha = node.get_attribute_value<double>("alpha", 1);
|
double alpha = node.get_attribute_value<double>("alpha", 1);
|
||||||
|
|
||||||
return NodeVector{std::make_shared<default_opset::Elu>(data, alpha)};
|
return OutputVector{std::make_shared<default_opset::Elu>(data, alpha)};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector elu(const Node& node);
|
OutputVector elu(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector equal(const Node& node)
|
inline OutputVector equal(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Equal>(node.get_ng_inputs().at(0),
|
return {std::make_shared<default_opset::Equal>(node.get_ng_inputs().at(0),
|
||||||
node.get_ng_inputs().at(1))};
|
node.get_ng_inputs().at(1))};
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector erf(const Node& node)
|
inline OutputVector erf(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Erf>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Erf>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector exp(const Node& node)
|
inline OutputVector exp(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -30,10 +30,10 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector expand(const Node& node)
|
OutputVector expand(const Node& node)
|
||||||
{
|
{
|
||||||
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
|
const Output<ngraph::Node> data{node.get_ng_inputs().at(0)};
|
||||||
const std::shared_ptr<ngraph::Node> shape{node.get_ng_inputs().at(1)};
|
const Output<ngraph::Node> shape{node.get_ng_inputs().at(1)};
|
||||||
|
|
||||||
return {std::make_shared<default_opset::Broadcast>(
|
return {std::make_shared<default_opset::Broadcast>(
|
||||||
data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)};
|
data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)};
|
||||||
|
@ -29,7 +29,7 @@ namespace ngraph
|
|||||||
// Expand operator has been available since version 8 of the default ONNX operator set.
|
// Expand operator has been available since version 8 of the default ONNX operator set.
|
||||||
// Currently, Expand is assigned to version 1 due to temporary reason.
|
// Currently, Expand is assigned to version 1 due to temporary reason.
|
||||||
{
|
{
|
||||||
NodeVector expand(const Node& node);
|
OutputVector expand(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -28,10 +28,10 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector eye_like(const Node& node)
|
OutputVector eye_like(const Node& node)
|
||||||
{
|
{
|
||||||
const auto input = node.get_ng_inputs().at(0);
|
const auto input = node.get_ng_inputs().at(0);
|
||||||
const auto& input_shape = input->get_shape();
|
const auto& input_shape = input.get_shape();
|
||||||
|
|
||||||
std::int64_t dtype;
|
std::int64_t dtype;
|
||||||
element::Type target_type;
|
element::Type target_type;
|
||||||
@ -44,7 +44,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
target_type = input->get_element_type();
|
target_type = input.get_element_type();
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK_VALID_NODE(node,
|
CHECK_VALID_NODE(node,
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector eye_like(const Node& node);
|
OutputVector eye_like(const Node& node);
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector fake_quantize(const onnx_import::Node& node)
|
OutputVector fake_quantize(const onnx_import::Node& node)
|
||||||
{
|
{
|
||||||
const auto inputs = node.get_ng_inputs();
|
const auto inputs = node.get_ng_inputs();
|
||||||
const auto X = inputs.at(0);
|
const auto X = inputs.at(0);
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector fake_quantize(const Node& node);
|
OutputVector fake_quantize(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -29,12 +29,12 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector flatten(const Node& node)
|
OutputVector flatten(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector inputs{node.get_ng_inputs()};
|
OutputVector inputs{node.get_ng_inputs()};
|
||||||
auto data = inputs.at(0);
|
auto data = inputs.at(0);
|
||||||
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
|
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
|
||||||
const auto data_rank = data->get_output_partial_shape(0).rank();
|
const auto data_rank = data.get_partial_shape().rank();
|
||||||
|
|
||||||
if (data_rank.is_static())
|
if (data_rank.is_static())
|
||||||
{
|
{
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector flatten(const Node& node);
|
OutputVector flatten(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector floor(const Node& node)
|
inline OutputVector floor(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Floor>(node.get_ng_inputs().at(0))};
|
return {std::make_shared<default_opset::Floor>(node.get_ng_inputs().at(0))};
|
||||||
}
|
}
|
||||||
|
@ -31,14 +31,14 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector gather(const Node& node)
|
inline OutputVector gather(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector ng_inputs{node.get_ng_inputs()};
|
OutputVector ng_inputs{node.get_ng_inputs()};
|
||||||
auto data = ng_inputs.at(0);
|
auto data = ng_inputs.at(0);
|
||||||
auto indices = ng_inputs.at(1);
|
auto indices = ng_inputs.at(1);
|
||||||
auto axis = node.get_attribute_value<int64_t>("axis", 0);
|
auto axis = node.get_attribute_value<int64_t>("axis", 0);
|
||||||
const auto valid_axis = ngraph::normalize_axis(
|
const auto valid_axis = ngraph::normalize_axis(
|
||||||
node.get_description(), axis, data->get_output_partial_shape(0).rank());
|
node.get_description(), axis, data.get_partial_shape().rank());
|
||||||
|
|
||||||
return {std::make_shared<default_opset::Gather>(
|
return {std::make_shared<default_opset::Gather>(
|
||||||
data,
|
data,
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
//*****************************************************************************
|
//*****************************************************************************
|
||||||
|
|
||||||
|
// Disabled in CMakeList
|
||||||
|
// Update to higher opset required
|
||||||
|
|
||||||
#include "ngraph/opsets/opset0.hpp"
|
#include "ngraph/opsets/opset0.hpp"
|
||||||
#include "utils/common.hpp"
|
#include "utils/common.hpp"
|
||||||
|
|
||||||
@ -25,9 +28,9 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector gather_nd(const Node& node)
|
OutputVector gather_nd(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector ng_inputs{node.get_ng_inputs()};
|
OutputVector ng_inputs{node.get_ng_inputs()};
|
||||||
auto data = ng_inputs.at(0);
|
auto data = ng_inputs.at(0);
|
||||||
auto indices = ng_inputs.at(1);
|
auto indices = ng_inputs.at(1);
|
||||||
|
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
//*****************************************************************************
|
//*****************************************************************************
|
||||||
|
|
||||||
|
// Disabled in CMakeList
|
||||||
|
// Update to higher opset required
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "core/node.hpp"
|
#include "core/node.hpp"
|
||||||
@ -27,7 +30,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector gather_nd(const Node& node);
|
OutputVector gather_nd(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -32,12 +32,12 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector gemm(const Node& node)
|
OutputVector gemm(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector inputs{node.get_ng_inputs()};
|
OutputVector inputs{node.get_ng_inputs()};
|
||||||
std::shared_ptr<ngraph::Node> input_a = inputs.at(0);
|
Output<ngraph::Node> input_a = inputs.at(0);
|
||||||
std::shared_ptr<ngraph::Node> input_b = inputs.at(1);
|
Output<ngraph::Node> input_b = inputs.at(1);
|
||||||
std::shared_ptr<ngraph::Node> input_c;
|
Output<ngraph::Node> input_c;
|
||||||
|
|
||||||
if (inputs.size() == 3)
|
if (inputs.size() == 3)
|
||||||
{
|
{
|
||||||
@ -46,16 +46,16 @@ namespace ngraph
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
input_c = default_opset::Constant::create(
|
input_c = default_opset::Constant::create(
|
||||||
input_b->get_element_type(), ngraph::Shape{}, {0});
|
input_b.get_element_type(), ngraph::Shape{}, {0});
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto alpha = node.get_attribute_value<float>("alpha", 1);
|
const auto alpha = node.get_attribute_value<float>("alpha", 1);
|
||||||
const auto beta = node.get_attribute_value<float>("beta", 1);
|
const auto beta = node.get_attribute_value<float>("beta", 1);
|
||||||
|
|
||||||
const auto alpha_node = default_opset::Constant::create(
|
const auto alpha_node = default_opset::Constant::create(
|
||||||
input_b->get_element_type(), Shape{}, {alpha});
|
input_b.get_element_type(), Shape{}, {alpha});
|
||||||
const auto beta_node = default_opset::Constant::create(
|
const auto beta_node = default_opset::Constant::create(
|
||||||
input_c->get_element_type(), Shape{}, {beta});
|
input_c.get_element_type(), Shape{}, {beta});
|
||||||
|
|
||||||
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
|
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
|
||||||
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
|
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
|
||||||
@ -85,7 +85,7 @@ namespace ngraph
|
|||||||
auto beta_times_input_c =
|
auto beta_times_input_c =
|
||||||
std::make_shared<default_opset::Multiply>(beta_node, input_c);
|
std::make_shared<default_opset::Multiply>(beta_node, input_c);
|
||||||
|
|
||||||
return NodeVector{
|
return OutputVector{
|
||||||
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
|
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,12 +93,12 @@ namespace ngraph
|
|||||||
|
|
||||||
namespace set_6
|
namespace set_6
|
||||||
{
|
{
|
||||||
NodeVector gemm(const Node& node)
|
OutputVector gemm(const Node& node)
|
||||||
{
|
{
|
||||||
NodeVector inputs{node.get_ng_inputs()};
|
OutputVector inputs{node.get_ng_inputs()};
|
||||||
std::shared_ptr<ngraph::Node> input_a = inputs.at(0);
|
Output<ngraph::Node> input_a = inputs.at(0);
|
||||||
std::shared_ptr<ngraph::Node> input_b = inputs.at(1);
|
Output<ngraph::Node> input_b = inputs.at(1);
|
||||||
std::shared_ptr<ngraph::Node> input_c;
|
Output<ngraph::Node> input_c;
|
||||||
|
|
||||||
if (inputs.size() == 3)
|
if (inputs.size() == 3)
|
||||||
{
|
{
|
||||||
@ -107,16 +107,16 @@ namespace ngraph
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
input_c = default_opset::Constant::create(
|
input_c = default_opset::Constant::create(
|
||||||
input_b->get_element_type(), ngraph::Shape{}, {0});
|
input_b.get_element_type(), ngraph::Shape{}, {0});
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto alpha = node.get_attribute_value<float>("alpha", 1);
|
const auto alpha = node.get_attribute_value<float>("alpha", 1);
|
||||||
const auto beta = node.get_attribute_value<float>("beta", 1);
|
const auto beta = node.get_attribute_value<float>("beta", 1);
|
||||||
|
|
||||||
const auto alpha_node = default_opset::Constant::create(
|
const auto alpha_node = default_opset::Constant::create(
|
||||||
input_b->get_element_type(), Shape{}, {alpha});
|
input_b.get_element_type(), Shape{}, {alpha});
|
||||||
const auto beta_node = default_opset::Constant::create(
|
const auto beta_node = default_opset::Constant::create(
|
||||||
input_c->get_element_type(), Shape{}, {beta});
|
input_c.get_element_type(), Shape{}, {beta});
|
||||||
|
|
||||||
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
|
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
|
||||||
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
|
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
|
||||||
@ -133,7 +133,7 @@ namespace ngraph
|
|||||||
auto beta_times_input_c =
|
auto beta_times_input_c =
|
||||||
std::make_shared<default_opset::Multiply>(beta_node, input_c);
|
std::make_shared<default_opset::Multiply>(beta_node, input_c);
|
||||||
|
|
||||||
return NodeVector{
|
return OutputVector{
|
||||||
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
|
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,13 +27,13 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector gemm(const Node& node);
|
OutputVector gemm(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
namespace set_6
|
namespace set_6
|
||||||
{
|
{
|
||||||
NodeVector gemm(const Node& node);
|
OutputVector gemm(const Node& node);
|
||||||
|
|
||||||
} // namespace set_6
|
} // namespace set_6
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector global_average_pool(const Node& node)
|
OutputVector global_average_pool(const Node& node)
|
||||||
{
|
{
|
||||||
return pooling::GlobalPoolingFactory(node).make_avg_pool();
|
return pooling::GlobalPoolingFactory(node).make_avg_pool();
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing Ngraph nodes producing output of ONNX
|
/// \return The vector containing Ngraph nodes producing output of ONNX
|
||||||
/// GlobalAveragePool operation.
|
/// GlobalAveragePool operation.
|
||||||
NodeVector global_average_pool(const Node& node);
|
OutputVector global_average_pool(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector global_max_pool(const Node& node)
|
OutputVector global_max_pool(const Node& node)
|
||||||
{
|
{
|
||||||
return pooling::GlobalPoolingFactory(node).make_max_pool();
|
return pooling::GlobalPoolingFactory(node).make_max_pool();
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ namespace ngraph
|
|||||||
///
|
///
|
||||||
/// \return The vector containing Ngraph nodes producing output of ONNX
|
/// \return The vector containing Ngraph nodes producing output of ONNX
|
||||||
/// GlobalMaxPool operation.
|
/// GlobalMaxPool operation.
|
||||||
NodeVector global_max_pool(const Node& node);
|
OutputVector global_max_pool(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector greater(const Node& node)
|
inline OutputVector greater(const Node& node)
|
||||||
{
|
{
|
||||||
return {std::make_shared<default_opset::Greater>(node.get_ng_inputs().at(0),
|
return {std::make_shared<default_opset::Greater>(node.get_ng_inputs().at(0),
|
||||||
node.get_ng_inputs().at(1))};
|
node.get_ng_inputs().at(1))};
|
||||||
|
@ -46,7 +46,7 @@ namespace ngraph
|
|||||||
if (linear_before_reset)
|
if (linear_before_reset)
|
||||||
{
|
{
|
||||||
const auto& ng_inputs = node.get_ng_inputs();
|
const auto& ng_inputs = node.get_ng_inputs();
|
||||||
const auto el_type = ng_inputs.at(0)->get_output_element_type(0);
|
const auto el_type = ng_inputs.at(0).get_element_type();
|
||||||
|
|
||||||
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
|
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
|
||||||
{
|
{
|
||||||
@ -68,18 +68,18 @@ namespace ngraph
|
|||||||
// ]
|
// ]
|
||||||
m_map[recurrent::OpInput::B] =
|
m_map[recurrent::OpInput::B] =
|
||||||
std::make_shared<default_opset::Concat>(
|
std::make_shared<default_opset::Concat>(
|
||||||
NodeVector{wr_z_bias,
|
OutputVector{wr_z_bias,
|
||||||
wr_r_bias,
|
wr_r_bias,
|
||||||
split_bias.at(2),
|
split_bias.at(2),
|
||||||
split_bias.at(5)},
|
split_bias.at(5)},
|
||||||
1);
|
1);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const std::size_t hidden_size =
|
const std::size_t hidden_size =
|
||||||
m_map[recurrent::OpInput::R]->get_shape().back();
|
m_map[recurrent::OpInput::R].get_shape().back();
|
||||||
const std::size_t num_directions =
|
const std::size_t num_directions =
|
||||||
m_map[recurrent::OpInput::W]->get_shape().front();
|
m_map[recurrent::OpInput::W].get_shape().front();
|
||||||
|
|
||||||
m_map[recurrent::OpInput::B] =
|
m_map[recurrent::OpInput::B] =
|
||||||
std::make_shared<default_opset::Constant>(
|
std::make_shared<default_opset::Constant>(
|
||||||
@ -110,7 +110,7 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector gru(const Node& node)
|
OutputVector gru(const Node& node)
|
||||||
{
|
{
|
||||||
constexpr std::size_t gates_count = 3;
|
constexpr std::size_t gates_count = 3;
|
||||||
GRUInputMap input_map{node, gates_count};
|
GRUInputMap input_map{node, gates_count};
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector gru(const Node& node);
|
OutputVector gru(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -27,17 +27,17 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector hard_sigmoid(const Node& node)
|
OutputVector hard_sigmoid(const Node& node)
|
||||||
{
|
{
|
||||||
const auto data = node.get_ng_inputs().at(0);
|
const auto data = node.get_ng_inputs().at(0);
|
||||||
|
|
||||||
const auto alpha = default_opset::Constant::create<double>(
|
const auto alpha = default_opset::Constant::create<double>(
|
||||||
data->get_element_type(),
|
data.get_element_type(),
|
||||||
Shape{},
|
Shape{},
|
||||||
std::vector<double>{node.get_attribute_value<double>("alpha", 0.2)});
|
std::vector<double>{node.get_attribute_value<double>("alpha", 0.2)});
|
||||||
|
|
||||||
const auto beta = default_opset::Constant::create<double>(
|
const auto beta = default_opset::Constant::create<double>(
|
||||||
data->get_element_type(),
|
data.get_element_type(),
|
||||||
Shape{},
|
Shape{},
|
||||||
std::vector<double>{node.get_attribute_value<double>("beta", 0.5)});
|
std::vector<double>{node.get_attribute_value<double>("beta", 0.5)});
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector hard_sigmoid(const Node& node);
|
OutputVector hard_sigmoid(const Node& node);
|
||||||
|
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -31,10 +31,10 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector hardmax(const Node& node)
|
OutputVector hardmax(const Node& node)
|
||||||
{
|
{
|
||||||
const auto input = node.get_ng_inputs().at(0);
|
const auto input = node.get_ng_inputs().at(0);
|
||||||
const auto& input_shape = input->get_output_partial_shape(0);
|
const auto& input_shape = input.get_partial_shape();
|
||||||
|
|
||||||
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
|
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
|
||||||
if (input_shape.rank().is_static())
|
if (input_shape.rank().is_static())
|
||||||
@ -48,11 +48,10 @@ namespace ngraph
|
|||||||
|
|
||||||
const auto coerced_tensor_shape =
|
const auto coerced_tensor_shape =
|
||||||
std::make_shared<default_opset::ShapeOf>(coerced_tensor);
|
std::make_shared<default_opset::ShapeOf>(coerced_tensor);
|
||||||
std::shared_ptr<ngraph::Node> row_size =
|
Output<ngraph::Node> row_size = std::make_shared<default_opset::Gather>(
|
||||||
std::make_shared<default_opset::Gather>(
|
coerced_tensor_shape,
|
||||||
coerced_tensor_shape,
|
default_opset::Constant::create(element::i64, {1}, {1}),
|
||||||
default_opset::Constant::create(element::i64, {1}, {1}),
|
default_opset::Constant::create(element::i64, {}, {0}));
|
||||||
default_opset::Constant::create(element::i64, {}, {0}));
|
|
||||||
row_size = ngraph::onnx_import::reshape::interpret_as_scalar(row_size);
|
row_size = ngraph::onnx_import::reshape::interpret_as_scalar(row_size);
|
||||||
|
|
||||||
const auto indices_axis = 1;
|
const auto indices_axis = 1;
|
||||||
@ -70,8 +69,8 @@ namespace ngraph
|
|||||||
|
|
||||||
const auto results = std::make_shared<default_opset::OneHot>(
|
const auto results = std::make_shared<default_opset::OneHot>(
|
||||||
topk->output(1), row_size, on_value, off_value, indices_axis);
|
topk->output(1), row_size, on_value, off_value, indices_axis);
|
||||||
const auto converted_results = std::make_shared<default_opset::Convert>(
|
const auto converted_results =
|
||||||
results, input->get_element_type());
|
std::make_shared<default_opset::Convert>(results, input.get_element_type());
|
||||||
|
|
||||||
if (input_shape.is_static())
|
if (input_shape.is_static())
|
||||||
{
|
{
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector hardmax(const Node& node);
|
OutputVector hardmax(const Node& node);
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -30,17 +30,17 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
inline NodeVector identity(const Node& node)
|
inline OutputVector identity(const Node& node)
|
||||||
{
|
{
|
||||||
auto input = node.get_ng_inputs().at(0);
|
auto input = node.get_ng_inputs().at(0);
|
||||||
if (input->get_element_type() == ngraph::element::boolean)
|
if (input.get_element_type() == ngraph::element::boolean)
|
||||||
{
|
{
|
||||||
const auto logic_zero =
|
const auto logic_zero =
|
||||||
default_opset::Constant::create(ngraph::element::boolean, {}, {false});
|
default_opset::Constant::create(ngraph::element::boolean, {}, {false});
|
||||||
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
|
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
|
||||||
}
|
}
|
||||||
const auto zero =
|
const auto zero =
|
||||||
default_opset::Constant::create(input->get_element_type(), {}, {0});
|
default_opset::Constant::create(input.get_element_type(), {}, {0});
|
||||||
return {std::make_shared<default_opset::Add>(input, zero)};
|
return {std::make_shared<default_opset::Add>(input, zero)};
|
||||||
}
|
}
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
@ -25,14 +25,14 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector image_scaler(const Node& node)
|
OutputVector image_scaler(const Node& node)
|
||||||
{
|
{
|
||||||
const auto inputs = node.get_ng_inputs();
|
const auto inputs = node.get_ng_inputs();
|
||||||
NGRAPH_CHECK(
|
NGRAPH_CHECK(
|
||||||
inputs.size() == 1, "ImageScaler 1 input tensor. Got: ", inputs.size());
|
inputs.size() == 1, "ImageScaler 1 input tensor. Got: ", inputs.size());
|
||||||
|
|
||||||
const auto data = inputs[0];
|
const auto data = inputs[0];
|
||||||
const auto& data_shape = data->get_output_partial_shape(0);
|
const auto& data_shape = data.get_partial_shape();
|
||||||
NGRAPH_CHECK(data_shape.rank().same_scheme({4}),
|
NGRAPH_CHECK(data_shape.rank().same_scheme({4}),
|
||||||
"ImageScaler expects a 4D tensor with NCHW format. Got: ",
|
"ImageScaler expects a 4D tensor with NCHW format. Got: ",
|
||||||
data_shape);
|
data_shape);
|
||||||
|
@ -27,7 +27,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector image_scaler(const Node& node);
|
OutputVector image_scaler(const Node& node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
NodeVector instance_norm(const Node& node)
|
OutputVector instance_norm(const Node& node)
|
||||||
{
|
{
|
||||||
Output<ngraph::Node> data(node.get_ng_inputs().at(0));
|
Output<ngraph::Node> data(node.get_ng_inputs().at(0));
|
||||||
Output<ngraph::Node> scale(node.get_ng_inputs().at(1));
|
Output<ngraph::Node> scale(node.get_ng_inputs().at(1));
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user