Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Steve Yoo 2021-08-13 13:27:26 +09:00
commit 04fed4c2af
1637 changed files with 84233 additions and 112281 deletions

View File

@ -1,54 +1,28 @@
BasedOnStyle: LLVM BasedOnStyle: Google
IndentWidth: 4 IndentWidth: 4
UseTab: Never UseTab: Never
ColumnLimit: 120
Language: Cpp Language: Cpp
Standard: Cpp11 Standard: Cpp11
AccessModifierOffset: -4 AccessModifierOffset: -4
AlignConsecutiveMacros: true
AlignConsecutiveDeclarations: false AllowAllArgumentsOnNextLine: false
AlignConsecutiveAssignments: false AllowAllConstructorInitializersOnNextLine: false
AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortBlocksOnASingleLine: true AllowShortIfStatementsOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: true AllowShortLambdasOnASingleLine: Empty
AllowShortFunctionsOnASingleLine: Inline AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false BinPackArguments: false
BinPackParameters: false BinPackParameters: false
CommentPragmas: '^#'
BreakBeforeBraces: Allman DerivePointerAlignment: false
BreakConstructorInitializersBeforeComma: true
ColumnLimit: 100
IndentCaseLabels: false
IndentWrappedFunctionNames: true
KeepEmptyLinesAtTheStartOfBlocks: false
NamespaceIndentation: All
PointerAlignment: Left
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SortIncludes: false
ReflowComments: true
IncludeCategories:
- Regex: '^".*'
Priority: 3
- Regex: '^<.*'
Priority: 2
SortIncludes: true
FixNamespaceComments: true FixNamespaceComments: true
IndentCaseLabels: false
IndentPPDirectives: AfterHash
ForEachMacros:
- foreach
- FOREACH_CHILD

View File

@ -12,15 +12,11 @@
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder class numpy_autobroadcast_incompatible_shapes : public ngraph::ngraph_error {
{
class numpy_autobroadcast_incompatible_shapes : public ngraph::ngraph_error
{
public: public:
numpy_autobroadcast_incompatible_shapes(const ngraph::Shape& shape1, numpy_autobroadcast_incompatible_shapes(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
const ngraph::Shape& shape2);
private: private:
const ngraph::Shape m_shape1; const ngraph::Shape m_shape1;
@ -84,8 +80,8 @@ namespace ngraph
/// elements point to ngraph::Node objects whose output values have the same shape. /// elements point to ngraph::Node objects whose output values have the same shape.
/// ///
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes /// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> numpy_broadcast(
numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args); const std::pair<Output<Node>, Output<Node>>& args);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix /// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication. /// multiplication.
@ -103,8 +99,7 @@ namespace ngraph
/// ///
/// \return The vector containing both outputs broadcasted. /// \return The vector containing both outputs broadcasted.
/// ///
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, const Output<Node>& right);
const Output<Node>& right);
/// \brief Cast shape of all input nodes for an element-wise operation that requires /// \brief Cast shape of all input nodes for an element-wise operation that requires
/// shape-compatibility /// shape-compatibility
@ -149,23 +144,19 @@ namespace ngraph
/// \return A pair that contains the target shape as its first object and a vector of /// \return A pair that contains the target shape as its first object and a vector of
/// padded input shapes ready to be broadcasted as the second object /// padded input shapes ready to be broadcasted as the second object
/// ///
std::pair<Shape, std::vector<Shape>> std::pair<Shape, std::vector<Shape>> get_numpy_broadcast_shapes(const std::vector<Shape>& input_shapes);
get_numpy_broadcast_shapes(const std::vector<Shape>& input_shapes);
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value, inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
const Shape& new_shape, const Shape& new_shape,
std::size_t start_match_axis) std::size_t start_match_axis) {
{ auto shape_const = op::Constant::create(element::u64, Shape{new_shape.size()}, new_shape);
auto shape_const =
op::Constant::create(element::u64, Shape{new_shape.size()}, new_shape);
return std::make_shared<op::v1::Broadcast>( return std::make_shared<op::v1::Broadcast>(
value, value,
shape_const, shape_const,
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis)); calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
} }
namespace opset1 namespace opset1 {
{
/// ///
/// \brief Broadcast right node to left node's shape using legacy scheme. /// \brief Broadcast right node to left node's shape using legacy scheme.
/// ///
@ -189,8 +180,7 @@ namespace ngraph
/// ///
/// \return The vector with axes indexes mapping . /// \return The vector with axes indexes mapping .
/// ///
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape, std::vector<std::size_t> get_axes_mapping(const Shape& output_shape, const AxisSet& broadcast_axes);
const AxisSet& broadcast_axes);
/// ///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation. /// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
@ -202,9 +192,7 @@ namespace ngraph
/// ///
/// \return Returns the Output object pointing to node with the axes mapping. /// \return Returns the Output object pointing to node with the axes mapping.
/// ///
Output<Node> get_axes_mapping_output(const Shape& output_shape, Output<Node> get_axes_mapping_output(const Shape& output_shape, const Shape& input_shape, std::size_t start_match_axis);
const Shape& input_shape,
std::size_t start_match_axis);
/// ///
/// \brief Creates Node returning the axes mapping for Broadcast operation. /// \brief Creates Node returning the axes mapping for Broadcast operation.
@ -230,16 +218,11 @@ namespace ngraph
/// ///
/// \return The Output object with Node returning axes mapping. /// \return The Output object with Node returning axes mapping.
/// ///
Output<Node> get_axes_mapping_output(const Shape& output_shape, Output<Node> get_axes_mapping_output(const Shape& output_shape, const AxisSet& broadcast_axes);
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node, Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, const AxisSet& broadcast_axes);
const Shape& target_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node, Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, std::size_t start_match_axis);
const Shape& target_shape,
std::size_t start_match_axis);
} // namespace opset1 } // namespace opset1
} // namespace builder } // namespace builder

View File

@ -10,14 +10,10 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/type/float16.hpp" #include "ngraph/type/float16.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder
{
template <class T> template <class T>
std::shared_ptr<Node> std::shared_ptr<Node> make_constant(const element::Type& type, const Shape& shape, const T& num) {
make_constant(const element::Type& type, const Shape& shape, const T& num)
{
std::shared_ptr<Node> val = nullptr; std::shared_ptr<Node> val = nullptr;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
@ -25,15 +21,15 @@ namespace ngraph
# pragma GCC diagnostic error "-Wswitch" # pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum" # pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (type) switch (type) {
{
case element::Type_t::f32: case element::Type_t::f32:
val = std::make_shared<ngraph::op::Constant>( val =
type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)}); std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
break; break;
case element::Type_t::f64: case element::Type_t::f64:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)}); ngraph::Shape{},
std::vector<double>{static_cast<double>(num)});
break; break;
case element::Type_t::f16: case element::Type_t::f16:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
@ -48,36 +44,44 @@ namespace ngraph
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))}); std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
break; break;
case element::Type_t::i64: case element::Type_t::i64:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)}); ngraph::Shape{},
std::vector<int64_t>{static_cast<int64_t>(num)});
break; break;
case element::Type_t::i32: case element::Type_t::i32:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)}); ngraph::Shape{},
std::vector<int32_t>{static_cast<int32_t>(num)});
break; break;
case element::Type_t::i16: case element::Type_t::i16:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)}); ngraph::Shape{},
std::vector<int16_t>{static_cast<int16_t>(num)});
break; break;
case element::Type_t::i8: case element::Type_t::i8:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)}); ngraph::Shape{},
std::vector<int8_t>{static_cast<int8_t>(num)});
break; break;
case element::Type_t::u64: case element::Type_t::u64:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)}); ngraph::Shape{},
std::vector<uint64_t>{static_cast<uint64_t>(num)});
break; break;
case element::Type_t::u32: case element::Type_t::u32:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)}); ngraph::Shape{},
std::vector<uint32_t>{static_cast<uint32_t>(num)});
break; break;
case element::Type_t::u16: case element::Type_t::u16:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)}); ngraph::Shape{},
std::vector<uint16_t>{static_cast<uint16_t>(num)});
break; break;
case element::Type_t::u8: case element::Type_t::u8:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(type,
type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)}); ngraph::Shape{},
std::vector<uint8_t>{static_cast<uint8_t>(num)});
break; break;
case element::Type_t::dynamic: case element::Type_t::dynamic:
throw ngraph_error("make_constant: Unsupported element type 'dynamic'"); throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
@ -96,11 +100,9 @@ namespace ngraph
# pragma GCC diagnostic pop # pragma GCC diagnostic pop
#endif #endif
if (shape.size() > 0) if (shape.size() > 0) {
{
ngraph::AxisSet axes; ngraph::AxisSet axes;
for (size_t i = 0; i < shape.size(); i++) for (size_t i = 0; i < shape.size(); i++) {
{
axes.insert(i); axes.insert(i);
} }
val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr(); val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
@ -119,7 +121,6 @@ namespace ngraph
/// ///
/// \return The Constant node which have expected type, shape and value. /// \return The Constant node which have expected type, shape and value.
/// ///
std::shared_ptr<Node> std::shared_ptr<Node> make_constant_from_double(const element::Type& type, const Shape& shape, double num);
make_constant_from_double(const element::Type& type, const Shape& shape, double num);
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph

View File

@ -10,21 +10,17 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder
{
/// \brief Specifies method of bias application to avoid numerical problems /// \brief Specifies method of bias application to avoid numerical problems
enum class BiasMode enum class BiasMode {
{
// Add bias to intermediate result // Add bias to intermediate result
ADD, ADD,
// Calculate max of intermediate result and bias // Calculate max of intermediate result and bias
MAX MAX
}; };
namespace opset1 namespace opset1 {
{
/// \brief Calculates L-0 norm of input tensor. /// \brief Calculates L-0 norm of input tensor.
/// ///
/// \note The L-0 norm represents the cardinality of elements different /// \note The L-0 norm represents the cardinality of elements different
@ -36,9 +32,7 @@ namespace ngraph
/// ///
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops. /// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
/// ///
std::shared_ptr<Node> l0_norm(const Output<Node>& value, std::shared_ptr<Node> l0_norm(const Output<Node>& value, const Output<Node>& reduction_axes, bool keep_dims = false);
const Output<Node>& reduction_axes,
bool keep_dims = false);
/// \brief Calculates L-1 norm of a value. /// \brief Calculates L-1 norm of a value.
/// ///

View File

@ -7,12 +7,9 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder namespace opset1 {
{
namespace opset1
{
// clang-format off // clang-format off
/// \brief Sum-based Mean of a Tensor. /// \brief Sum-based Mean of a Tensor.
/// ///
@ -36,13 +33,9 @@ namespace ngraph
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | /// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. | /// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
// clang-format on // clang-format on
std::shared_ptr<Node> mean(const Output<Node>& node, std::shared_ptr<Node> mean(const Output<Node>& node, const AxisSet& reduction_axes, bool keep_dims = false);
const AxisSet& reduction_axes,
bool keep_dims = false);
std::shared_ptr<Node> mean(const Output<Node>& node, std::shared_ptr<Node> mean(const Output<Node>& node, const Output<Node>& reduction_axes, bool keep_dims = false);
const Output<Node>& reduction_axes,
bool keep_dims = false);
// clang-format off // clang-format off
/// \brief Sum-based Variance of a Tensor. /// \brief Sum-based Variance of a Tensor.

View File

@ -11,12 +11,9 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder namespace opset1 {
{
namespace opset1
{
/// \brief Change shape of a value /// \brief Change shape of a value
/// ///
/// \param[in] value The value to be reshaped. /// \param[in] value The value to be reshaped.
@ -31,8 +28,7 @@ namespace ngraph
/// \param axes_order The permutation of axes. /// \param axes_order The permutation of axes.
/// ///
/// \return Transpose:v1 op. /// \return Transpose:v1 op.
std::shared_ptr<Node> reorder_axes(const Output<Node>& value, std::shared_ptr<Node> reorder_axes(const Output<Node>& value, std::vector<size_t> axes_order = {});
std::vector<size_t> axes_order = {});
/// \brief Return transposed value (with axes in reversed order). /// \brief Return transposed value (with axes in reversed order).
/// ///
@ -66,8 +62,7 @@ namespace ngraph
/// \param[in] axes The vector defining indexes of axes to be removed. /// \param[in] axes The vector defining indexes of axes to be removed.
/// ///
/// \return Reshape:v1 op. /// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value, std::shared_ptr<Node> squeeze(const Output<Node>& value, std::vector<std::size_t> axes = {0});
std::vector<std::size_t> axes = {0});
/// \brief Collapse specified axes into single one. /// \brief Collapse specified axes into single one.
/// ///
@ -79,9 +74,7 @@ namespace ngraph
/// ///
/// \return The node with collapsed specified axes. /// \return The node with collapsed specified axes.
/// ///
std::shared_ptr<Node> collapse(const Output<Node>& value, std::shared_ptr<Node> collapse(const Output<Node>& value, const std::size_t start_axis, const std::size_t end_axis);
const std::size_t start_axis,
const std::size_t end_axis);
} // namespace opset1 } // namespace opset1
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph

View File

@ -7,10 +7,8 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder
{
/// \brief Split value on specified axis into multiple parts. /// \brief Split value on specified axis into multiple parts.
/// ///
/// \param value The value to be split. /// \param value The value to be split.
@ -20,9 +18,7 @@ namespace ngraph
/// \return The vector containing multiple nodes we split input node into. /// \return The vector containing multiple nodes we split input node into.
/// ///
NGRAPH_DEPRECATED("This builder was deprecated.") NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value, OutputVector split(const Output<Node>& value, const std::vector<size_t>& length_parts, int64_t axis = 0);
const std::vector<size_t>& length_parts,
int64_t axis = 0);
/// \brief Split node on specified axis into multiple parts. /// \brief Split node on specified axis into multiple parts.
/// ///
@ -41,8 +37,7 @@ namespace ngraph
NGRAPH_DEPRECATED("This builder was deprecated.") NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0); OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
namespace opset1 namespace opset1 {
{
/// \brief Split value on specified axis into multiple parts. /// \brief Split value on specified axis into multiple parts.
/// ///
/// \param value The value to be split. /// \param value The value to be split.
@ -56,9 +51,7 @@ namespace ngraph
/// \return The vector containing multiple outputs we split input node into. /// \return The vector containing multiple outputs we split input node into.
/// The vector is output of Split:v1 op /// The vector is output of Split:v1 op
/// ///
OutputVector split(const Output<Node>& value, OutputVector split(const Output<Node>& value, const std::vector<size_t>& split_lengths, int64_t axis = 0);
const std::vector<size_t>& split_lengths,
int64_t axis = 0);
/// \brief Split value on specified axis into multiple parts. /// \brief Split value on specified axis into multiple parts.
/// ///

View File

@ -18,21 +18,15 @@
using namespace std; using namespace std;
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(const Shape& shape1,
{
numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(
const Shape& shape1, const Shape& shape2)
: ngraph_error(error_str(shape1, shape2))
, m_shape1(shape1)
, m_shape2(shape2)
{
}
string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1,
const Shape& shape2) const Shape& shape2)
{ : ngraph_error(error_str(shape1, shape2)),
m_shape1(shape1),
m_shape2(shape2) {}
string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1, const Shape& shape2) {
ostringstream os; ostringstream os;
os << "Auto-broadcast not possible for these input shapes:" os << "Auto-broadcast not possible for these input shapes:"
<< " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2); << " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
@ -52,8 +46,7 @@ namespace ngraph
/// ///
/// \return Broadcast shape of input shapes. /// \return Broadcast shape of input shapes.
/// ///
static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape) static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape) {
{
Shape result; Shape result;
auto lhs_rank = lhs_shape.size(); auto lhs_rank = lhs_shape.size();
auto rhs_rank = rhs_shape.size(); auto rhs_rank = rhs_shape.size();
@ -64,13 +57,11 @@ namespace ngraph
// left-pad the rhs_shape with ones // left-pad the rhs_shape with ones
rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1); rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
for (size_t index = 0; index < max_rank; ++index) for (size_t index = 0; index < max_rank; ++index) {
{
size_t lhs_dim = lhs_shape.at(index); size_t lhs_dim = lhs_shape.at(index);
size_t rhs_dim = rhs_shape.at(index); size_t rhs_dim = rhs_shape.at(index);
if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) {
{
throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape); throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
} }
@ -80,29 +71,23 @@ namespace ngraph
return result; return result;
}; };
pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes) pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes) {
{ Shape target_shape = accumulate(begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
Shape target_shape = accumulate(
begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
vector<Shape> full_shapes; vector<Shape> full_shapes;
for (const Shape& input : input_shapes) for (const Shape& input : input_shapes) {
{
Shape padded_shape{input}; Shape padded_shape{input};
padded_shape.insert( padded_shape.insert(begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
full_shapes.push_back(move(padded_shape)); full_shapes.push_back(move(padded_shape));
} }
return {target_shape, full_shapes}; return {target_shape, full_shapes};
} }
static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values) static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values) {
{
vector<Shape> input_shapes; vector<Shape> input_shapes;
for (const auto& input : values) for (const auto& input : values) {
{
input_shapes.push_back(input.get_shape()); input_shapes.push_back(input.get_shape());
} }
@ -125,12 +110,10 @@ namespace ngraph
/// ///
static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value, static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
const Shape& output_shape, const Shape& output_shape,
const Shape& source_shape) const Shape& source_shape) {
{
shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr(); shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
// If node already has the required shape, return original node // If node already has the required shape, return original node
if (output_shape == value.get_shape()) if (output_shape == value.get_shape()) {
{
return broadcasted_node; return broadcasted_node;
} }
@ -145,29 +128,22 @@ namespace ngraph
// Positions of axes which have length of 1 are needed to calculate broadcast_axes // Positions of axes which have length of 1 are needed to calculate broadcast_axes
// for nGraph broadcast operation. We need to remove ones from source shape // for nGraph broadcast operation. We need to remove ones from source shape
// to avoid broadcasting axis conflict. // to avoid broadcasting axis conflict.
for (size_t index = 0; index < output_shape.size(); ++index) for (size_t index = 0; index < output_shape.size(); ++index) {
{ if (source_shape.at(index) == 1 && output_shape.at(index) != 1) {
if (source_shape.at(index) == 1 && output_shape.at(index) != 1)
{
broadcast_axes.push_back(index); broadcast_axes.push_back(index);
} } else {
else
{
squeezed_shape.push_back(source_shape.at(index)); squeezed_shape.push_back(source_shape.at(index));
} }
} }
if (squeezed_shape != value.get_shape()) if (squeezed_shape != value.get_shape()) {
{
broadcasted_node = builder::opset1::reshape(value, squeezed_shape); broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
} }
if (!broadcast_axes.empty()) if (!broadcast_axes.empty()) {
{ auto shape_const = op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
auto shape_const = broadcasted_node =
op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape); make_shared<op::v1::Broadcast>(broadcasted_node,
broadcasted_node = make_shared<op::v1::Broadcast>(
broadcasted_node,
shape_const, shape_const,
opset1::get_axes_mapping_output(output_shape, broadcast_axes)); opset1::get_axes_mapping_output(output_shape, broadcast_axes));
} }
@ -183,57 +159,45 @@ namespace ngraph
/// ///
/// \return The broadcasted Node. /// \return The broadcasted Node.
/// ///
static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value, static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value, const Shape& output_shape, int64_t axis) {
const Shape& output_shape,
int64_t axis)
{
auto value_shape = value.get_shape(); auto value_shape = value.get_shape();
// If node already has the required shape, return original node // If node already has the required shape, return original node
if (output_shape == value_shape) if (output_shape == value_shape) {
{
return value.get_node_shared_ptr(); return value.get_node_shared_ptr();
} }
if (axis == -1) if (axis == -1) {
{
axis = output_shape.size() - value_shape.size(); axis = output_shape.size() - value_shape.size();
} }
auto trimmed_value_shape = value_shape; auto trimmed_value_shape = value_shape;
while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1) while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1) {
{
trimmed_value_shape.pop_back(); trimmed_value_shape.pop_back();
} }
AxisSet axes; AxisSet axes;
for (int64_t i = 0; i < axis; ++i) for (int64_t i = 0; i < axis; ++i) {
{
axes.insert(static_cast<size_t>(i)); axes.insert(static_cast<size_t>(i));
} }
for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i) for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i) {
{
axes.insert(i); axes.insert(i);
} }
auto trimmed_value = value; auto trimmed_value = value;
if (value_shape != trimmed_value_shape) if (value_shape != trimmed_value_shape) {
{
trimmed_value = builder::opset1::reshape(value, trimmed_value_shape); trimmed_value = builder::opset1::reshape(value, trimmed_value_shape);
} }
auto shape_const = auto shape_const = op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape); auto value_bcast =
auto value_bcast = make_shared<op::v1::Broadcast>( make_shared<op::v1::Broadcast>(trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));
trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));
return move(value_bcast); return move(value_bcast);
} }
pair<shared_ptr<Node>, shared_ptr<Node>> pair<shared_ptr<Node>, shared_ptr<Node>> numpy_broadcast(const pair<Output<Node>, Output<Node>>& args) {
numpy_broadcast(const pair<Output<Node>, Output<Node>>& args)
{
NGRAPH_CHECK(args.first.get_node()); NGRAPH_CHECK(args.first.get_node());
NGRAPH_CHECK(args.second.get_node()); NGRAPH_CHECK(args.second.get_node());
@ -241,22 +205,17 @@ namespace ngraph
const Shape& arg2_in_shape = args.second.get_shape(); const Shape& arg2_in_shape = args.second.get_shape();
// Handle the trivial case... // Handle the trivial case...
if (arg1_in_shape == arg2_in_shape) if (arg1_in_shape == arg2_in_shape) {
{ return make_pair(args.first.get_node_shared_ptr(), args.second.get_node_shared_ptr());
return make_pair(args.first.get_node_shared_ptr(),
args.second.get_node_shared_ptr());
} }
NodeVector bcasted_outputs = NodeVector bcasted_outputs = as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1)); return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
} }
OutputVector numpy_broadcast_outputs(const OutputVector& values) OutputVector numpy_broadcast_outputs(const OutputVector& values) {
{ if (values.size() <= 1) {
if (values.size() <= 1)
{
return values; return values;
} }
@ -264,37 +223,29 @@ namespace ngraph
auto bcast_shapes = get_numpy_broadcast_shapes(values); auto bcast_shapes = get_numpy_broadcast_shapes(values);
OutputVector broadcasted_inputs; OutputVector broadcasted_inputs;
for (size_t i = 0; i < values.size(); ++i) for (size_t i = 0; i < values.size(); ++i) {
{ broadcasted_inputs.push_back(numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
broadcasted_inputs.push_back(
numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
} }
return broadcasted_inputs; return broadcasted_inputs;
} }
shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape) shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape) {
{
auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape}); auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]); return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
} }
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, const Output<Node>& right) {
const Output<Node>& right)
{
const auto& left_shape = left.get_shape(); const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape(); const auto& right_shape = right.get_shape();
// Broadcast only _stack of matrices_ axes. // Broadcast only _stack of matrices_ axes.
const auto& numpy_shapes = const auto& numpy_shapes = get_numpy_broadcast_shapes(
get_numpy_broadcast_shapes({Shape{begin(left_shape), next(end(left_shape), -2)}, {Shape{begin(left_shape), next(end(left_shape), -2)}, Shape{begin(right_shape), next(end(right_shape), -2)}});
Shape{begin(right_shape), next(end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes. // Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
auto left_output_shape = numpy_shapes.first; auto left_output_shape = numpy_shapes.first;
auto right_output_shape = numpy_shapes.first; auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions. // Append the last two axes original dimensions.
left_output_shape.insert(end(left_output_shape), left_output_shape.insert(end(left_output_shape), next(begin(left_shape), left_shape.size() - 2), end(left_shape));
next(begin(left_shape), left_shape.size() - 2),
end(left_shape));
right_output_shape.insert(end(right_output_shape), right_output_shape.insert(end(right_output_shape),
next(begin(right_shape), right_shape.size() - 2), next(begin(right_shape), right_shape.size() - 2),
end(right_shape)); end(right_shape));
@ -302,37 +253,28 @@ namespace ngraph
auto left_full_shape = numpy_shapes.second.at(0); auto left_full_shape = numpy_shapes.second.at(0);
auto right_full_shape = numpy_shapes.second.at(1); auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions. // Append the last two axes original dimensions.
left_full_shape.insert(end(left_full_shape), left_full_shape.insert(end(left_full_shape), next(begin(left_shape), left_shape.size() - 2), end(left_shape));
next(begin(left_shape), left_shape.size() - 2), right_full_shape.insert(end(right_full_shape), next(begin(right_shape), right_shape.size() - 2), end(right_shape));
end(left_shape));
right_full_shape.insert(end(right_full_shape),
next(begin(right_shape), right_shape.size() - 2),
end(right_shape));
return {numpy_broadcast_node(left, left_output_shape, left_full_shape), return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
numpy_broadcast_node(right, right_output_shape, right_full_shape)}; numpy_broadcast_node(right, right_output_shape, right_full_shape)};
} }
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis) OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis) {
{ if (inputs.size() <= 1) {
if (inputs.size() <= 1)
{
return inputs; return inputs;
} }
OutputVector broadcasted_inputs{inputs[0]}; OutputVector broadcasted_inputs{inputs[0]};
for (size_t i = 1; i < inputs.size(); ++i) for (size_t i = 1; i < inputs.size(); ++i) {
{ broadcasted_inputs.push_back(broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
broadcasted_inputs.push_back(
broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
} }
return broadcasted_inputs; return broadcasted_inputs;
} }
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape, std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape, const Shape& input_shape,
size_t start_match_axis) size_t start_match_axis) {
{
vector<size_t> axes(output_shape.size() - input_shape.size()); vector<size_t> axes(output_shape.size() - input_shape.size());
// Populate the axes vector with monotonic increasing series from 0 until // Populate the axes vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range: // output_shape_size, excluding values in range:
@ -344,53 +286,41 @@ namespace ngraph
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping); return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
} }
namespace opset1 namespace opset1 {
{
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left, Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right, const Output<Node>& right,
size_t start_match_axis) size_t start_match_axis) {
{
const auto& left_shape = left.get_shape(); const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape(); const auto& right_shape = right.get_shape();
bool dimensions_identical = (left_shape == right_shape); bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical) if (dimensions_identical) {
{
return right; return right;
} }
// Prepare new shape of right operand for broadcasting // Prepare new shape of right operand for broadcasting
// Remove dimensions with length=1 from back // Remove dimensions with length=1 from back
auto new_right_shape = right_shape; auto new_right_shape = right_shape;
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension) for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension) {
{ if (new_right_shape.at(dimension) == 1) {
if (new_right_shape.at(dimension) == 1)
{
new_right_shape.pop_back(); new_right_shape.pop_back();
} } else {
else
{
break; break;
} }
} }
// Find first dimensions at front with length different from 1 // Find first dimensions at front with length different from 1
size_t num_ones = 0; size_t num_ones = 0;
for (size_t dimension : new_right_shape) for (size_t dimension : new_right_shape) {
{ if (dimension == 1) {
if (dimension == 1)
{
++num_ones; ++num_ones;
} } else {
else
{
break; break;
} }
} }
// Remove dimensions with length=1 from front // Remove dimensions with length=1 from front
new_right_shape.erase(begin(new_right_shape), new_right_shape.erase(begin(new_right_shape), next(begin(new_right_shape), num_ones));
next(begin(new_right_shape), num_ones));
auto reshape_right = reshape(right, new_right_shape); auto reshape_right = reshape(right, new_right_shape);
@ -400,14 +330,11 @@ namespace ngraph
return make_broadcast(reshape_right, left_shape, start_match_axis); return make_broadcast(reshape_right, left_shape, start_match_axis);
} }
vector<size_t> get_axes_mapping(const Shape& output_shape, vector<size_t> get_axes_mapping(const Shape& output_shape, const AxisSet& broadcast_axes) {
const AxisSet& broadcast_axes)
{
NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size())); NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
vector<size_t> axes_mapping(output_shape.size()); vector<size_t> axes_mapping(output_shape.size());
iota(axes_mapping.begin(), axes_mapping.end(), 0); iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i) for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i) {
{
axes_mapping.erase(axes_mapping.begin() + *i); axes_mapping.erase(axes_mapping.begin() + *i);
} }
return axes_mapping; return axes_mapping;
@ -415,13 +342,11 @@ namespace ngraph
Output<Node> get_axes_mapping_output(const PartialShape& output_shape, Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const PartialShape& input_shape, const PartialShape& input_shape,
std::size_t start_match_axis) std::size_t start_match_axis) {
{
NGRAPH_CHECK((input_shape.rank().is_static() && output_shape.rank().is_static()), NGRAPH_CHECK((input_shape.rank().is_static() && output_shape.rank().is_static()),
"Tensor's rank has to be static."); "Tensor's rank has to be static.");
NGRAPH_CHECK( NGRAPH_CHECK(
(input_shape.rank().get_length() + static_cast<int64_t>(start_match_axis) <= (input_shape.rank().get_length() + static_cast<int64_t>(start_match_axis) <= output_shape.rank().get_length()),
output_shape.rank().get_length()),
"Unable to figure out axes mapping."); "Unable to figure out axes mapping.");
vector<int64_t> mapping(input_shape.rank().get_length()); vector<int64_t> mapping(input_shape.rank().get_length());
@ -430,54 +355,40 @@ namespace ngraph
return op::Constant::create(element::i64, Shape{mapping.size()}, mapping); return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
} }
Output<Node> get_axes_mapping_output(const Shape& output_shape, Output<Node> get_axes_mapping_output(const Shape& output_shape, const AxisSet& broadcast_axes) {
const AxisSet& broadcast_axes)
{
vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)}; vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping); return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
} }
Output<Node> get_axes_mapping_output(const PartialShape& output_shape, Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const Output<Node>& input_shape, const Output<Node>& input_shape,
std::size_t start_match_axis) std::size_t start_match_axis) {
{
const auto one_node = opset7::Constant::create(element::i64, Shape{}, {1}); const auto one_node = opset7::Constant::create(element::i64, Shape{}, {1});
const auto zero_node = opset7::Constant::create(element::i64, Shape{}, {0}); const auto zero_node = opset7::Constant::create(element::i64, Shape{}, {0});
const auto start_match_axis_node = const auto start_match_axis_node = opset7::Constant::create(element::i64, Shape{}, {start_match_axis});
opset7::Constant::create(element::i64, Shape{}, {start_match_axis}); const auto target_shape_rank_node =
const auto target_shape_rank_node = builder::opset1::reshape( builder::opset1::reshape(std::make_shared<opset7::ShapeOf>(input_shape), Shape{});
std::make_shared<opset7::ShapeOf>(input_shape), Shape{});
const auto range_node = std::make_shared<opset7::Range>( const auto range_node = std::make_shared<opset7::Range>(zero_node, target_shape_rank_node, one_node, element::i64);
zero_node, target_shape_rank_node, one_node, element::i64);
// workaround for GPU plugin type incompatibility // workaround for GPU plugin type incompatibility
const auto range_node_converted = std::make_shared<opset7::Convert>( const auto range_node_converted =
range_node, start_match_axis_node->get_element_type()); std::make_shared<opset7::Convert>(range_node, start_match_axis_node->get_element_type());
// end of workaround // end of workaround
const auto result = const auto result = std::make_shared<opset7::Add>(range_node_converted, start_match_axis_node);
std::make_shared<opset7::Add>(range_node_converted, start_match_axis_node);
return result; return result;
} }
Output<Node> make_broadcast(const Output<Node>& node, Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, const AxisSet& broadcast_axes) {
const Shape& target_shape, return make_shared<op::v1::Broadcast>(node,
const AxisSet& broadcast_axes)
{
return make_shared<op::v1::Broadcast>(
node,
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape), op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
get_axes_mapping_output(target_shape, broadcast_axes)); get_axes_mapping_output(target_shape, broadcast_axes));
} }
Output<Node> make_broadcast(const Output<Node>& node, Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, size_t start_match_axis) {
const Shape& target_shape,
size_t start_match_axis)
{
const auto node_shape = std::make_shared<opset7::ShapeOf>(node); const auto node_shape = std::make_shared<opset7::ShapeOf>(node);
return make_shared<op::v1::Broadcast>( return make_shared<op::v1::Broadcast>(node,
node,
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape), op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
get_axes_mapping_output(target_shape, node_shape, start_match_axis)); get_axes_mapping_output(target_shape, node_shape, start_match_axis));
} }

View File

@ -4,83 +4,60 @@
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder std::shared_ptr<Node> make_constant_from_double(const element::Type& type, const Shape& shape, double num) {
{ auto ceil_func = [](double x) {
std::shared_ptr<Node> return ceil(x);
make_constant_from_double(const element::Type& type, const Shape& shape, double num) };
{
auto ceil_func = [](double x) { return ceil(x); };
std::shared_ptr<ngraph::Node> result = nullptr; std::shared_ptr<ngraph::Node> result = nullptr;
switch (type) switch (type) {
{ case element::Type_t::i8: {
case element::Type_t::i8: result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int8_t>(num, ceil_func));
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int8_t>(num, ceil_func));
break; break;
} }
case element::Type_t::i16: case element::Type_t::i16: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int16_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int16_t>(num, ceil_func));
break; break;
} }
case element::Type_t::i32: case element::Type_t::i32: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int32_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int32_t>(num, ceil_func));
break; break;
} }
case element::Type_t::i64: case element::Type_t::i64: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int64_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int64_t>(num, ceil_func));
break; break;
} }
case element::Type_t::u8: case element::Type_t::u8: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint8_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint8_t>(num, ceil_func));
break; break;
} }
case element::Type_t::u16: case element::Type_t::u16: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint16_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint16_t>(num, ceil_func));
break; break;
} }
case element::Type_t::u32: case element::Type_t::u32: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint32_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint32_t>(num, ceil_func));
break; break;
} }
case element::Type_t::u64: case element::Type_t::u64: {
{ result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint64_t>(num, ceil_func));
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint64_t>(num, ceil_func));
break; break;
} }
case element::Type_t::f16: case element::Type_t::f16: {
{
result = builder::make_constant(type, shape, static_cast<float16>(num)); result = builder::make_constant(type, shape, static_cast<float16>(num));
break; break;
} }
case element::Type_t::bf16: case element::Type_t::bf16: {
{
result = builder::make_constant(type, shape, static_cast<bfloat16>(num)); result = builder::make_constant(type, shape, static_cast<bfloat16>(num));
break; break;
} }
case element::Type_t::f32: case element::Type_t::f32: {
{
result = builder::make_constant(type, shape, static_cast<float>(num)); result = builder::make_constant(type, shape, static_cast<float>(num));
break; break;
} }
case element::Type_t::f64: case element::Type_t::f64: {
{
result = builder::make_constant(type, shape, num); result = builder::make_constant(type, shape, num);
break; break;
} }

View File

@ -3,6 +3,7 @@
// //
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
@ -18,105 +19,84 @@
using namespace std; using namespace std;
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder namespace detail {
{ namespace opset1 {
namespace detail
{
namespace opset1
{
shared_ptr<Node> lp_norm(const Output<Node>& value, shared_ptr<Node> lp_norm(const Output<Node>& value,
size_t p_norm, size_t p_norm,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
float bias, float bias,
bool keep_dims) bool keep_dims) {
{
// In general "entrywise" lp-norm for matrix `A` is defined as following double // In general "entrywise" lp-norm for matrix `A` is defined as following double
// sum: // sum:
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p} // ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
shared_ptr<Node> abs_values{make_shared<ngraph::opset1::Abs>(value)}; shared_ptr<Node> abs_values{make_shared<ngraph::opset1::Abs>(value)};
shared_ptr<Node> p_node = ngraph::opset1::Constant::create( shared_ptr<Node> p_node = ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {p_norm});
value.get_element_type(), Shape{}, {p_norm});
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes. // Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)}; shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)};
values = values = make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create( shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
values->get_element_type(), Shape{}, {bias})};
values = make_shared<ngraph::opset1::Add>(values, bias_node); values = make_shared<ngraph::opset1::Add>(values, bias_node);
// Get outer part of equation: raise values to 1/p_norm exponent. // Get outer part of equation: raise values to 1/p_norm exponent.
shared_ptr<Node> inv_p_node = ngraph::opset1::Constant::create( shared_ptr<Node> inv_p_node = ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {1.f / p_norm});
values->get_element_type(), Shape{}, {1.f / p_norm});
return {make_shared<ngraph::opset1::Power>(values, inv_p_node) return {make_shared<ngraph::opset1::Power>(values, inv_p_node)->add_provenance_group_members_above({value})};
->add_provenance_group_members_above({value})};
} }
} // namespace opset1 } // namespace opset1
} // namespace detail } // namespace detail
shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value, shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims) bool keep_dims) {
{
// L0 norm returns number of elements different from zero. // L0 norm returns number of elements different from zero.
const shared_ptr<Node> zero_node{ const shared_ptr<Node> zero_node{ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {0.f})};
ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {0.f})};
// Convert bool values to input node data type. // Convert bool values to input node data type.
const shared_ptr<Node> non_zero_values = make_shared<ngraph::opset1::Convert>( const shared_ptr<Node> non_zero_values =
make_shared<ngraph::opset1::NotEqual>(value, zero_node), value.get_element_type()); make_shared<ngraph::opset1::Convert>(make_shared<ngraph::opset1::NotEqual>(value, zero_node),
value.get_element_type());
return make_shared<ngraph::opset1::ReduceSum>( return make_shared<ngraph::opset1::ReduceSum>(non_zero_values, reduction_axes, keep_dims)
non_zero_values, reduction_axes, keep_dims)
->add_provenance_group_members_above({value}); ->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value, shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
float bias, float bias,
bool keep_dims) bool keep_dims) {
{ const shared_ptr<Node> values{
const shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>( make_shared<ngraph::opset1::ReduceSum>(make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
const shared_ptr<Node> bias_node{ const shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
return make_shared<ngraph::opset1::Add>(values, bias_node) return make_shared<ngraph::opset1::Add>(values, bias_node)->add_provenance_group_members_above({value});
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::opset1::l2_norm(const Output<Node>& value, shared_ptr<Node> builder::opset1::l2_norm(const Output<Node>& value,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
float bias, float bias,
BiasMode bias_mode, BiasMode bias_mode,
bool keep_dims) bool keep_dims) {
{ shared_ptr<Node> pow =
shared_ptr<Node> pow = make_shared<ngraph::opset1::Power>( make_shared<ngraph::opset1::Power>(value,
value, make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2)); make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2));
shared_ptr<Node> values{ shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
shared_ptr<Node> bias_node{ shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
shared_ptr<Node> result; shared_ptr<Node> result;
switch (bias_mode) switch (bias_mode) {
{ case BiasMode::MAX: {
case BiasMode::MAX: result = make_shared<ngraph::opset1::Sqrt>(make_shared<ngraph::opset1::Maximum>(values, bias_node));
{
result = make_shared<ngraph::opset1::Sqrt>(
make_shared<ngraph::opset1::Maximum>(values, bias_node));
break; break;
} }
case BiasMode::ADD: case BiasMode::ADD:
default: default:
result = make_shared<ngraph::opset1::Sqrt>( result = make_shared<ngraph::opset1::Sqrt>(make_shared<ngraph::opset1::Add>(values, bias_node));
make_shared<ngraph::opset1::Add>(values, bias_node));
} }
return result->add_provenance_group_members_above({value}); return result->add_provenance_group_members_above({value});
} }
@ -125,26 +105,21 @@ namespace ngraph
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
size_t p_norm, size_t p_norm,
float bias, float bias,
bool keep_dims) bool keep_dims) {
{
// The number of non-zero elements // The number of non-zero elements
if (p_norm == 0) if (p_norm == 0) {
{
return opset1::l0_norm(value, reduction_axes, keep_dims); return opset1::l0_norm(value, reduction_axes, keep_dims);
} }
// sum of absolute values. // sum of absolute values.
else if (p_norm == 1) else if (p_norm == 1) {
{
return opset1::l1_norm(value, reduction_axes, bias, keep_dims); return opset1::l1_norm(value, reduction_axes, bias, keep_dims);
} }
// sqrt of sum of squares - Euclidean norm // sqrt of sum of squares - Euclidean norm
else if (p_norm == 2) else if (p_norm == 2) {
{
return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims); return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims);
} }
// generic case // generic case
else else {
{
return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims); return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims);
} }
} }

View File

@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "ngraph/builder/reduce_ops.hpp"
#include <numeric> #include <numeric>
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
@ -16,54 +17,39 @@
#include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset1.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder size_t get_num_elements(const Shape& shape, const AxisSet& reduction_axes) {
{
size_t get_num_elements(const Shape& shape, const AxisSet& reduction_axes)
{
size_t N = 1; size_t N = 1;
for (auto a : reduction_axes) for (auto a : reduction_axes) {
{
N *= shape[a]; N *= shape[a];
} }
return N; return N;
} }
std::shared_ptr<Node> get_num_elements(const Output<Node>& value, std::shared_ptr<Node> get_num_elements(const Output<Node>& value, const Output<Node>& reduction_axes) {
const Output<Node>& reduction_axes)
{
const auto value_shape = std::make_shared<ngraph::opset1::ShapeOf>(value); const auto value_shape = std::make_shared<ngraph::opset1::ShapeOf>(value);
const auto dim_values = std::make_shared<ngraph::opset1::Gather>( const auto dim_values =
value_shape, std::make_shared<ngraph::opset1::Gather>(value_shape,
reduction_axes, reduction_axes,
ngraph::opset1::Constant::create(element::i64, {}, {0})); ngraph::opset1::Constant::create(element::i64, {}, {0}));
return std::make_shared<ngraph::opset1::ReduceProd>( return std::make_shared<ngraph::opset1::ReduceProd>(dim_values,
dim_values, ngraph::opset1::Constant::create(element::i64, {}, {0})); ngraph::opset1::Constant::create(element::i64, {}, {0}));
} }
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value, std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value, const AxisSet& reduction_axes, bool keep_dims) {
const AxisSet& reduction_axes,
bool keep_dims)
{
std::shared_ptr<Node> elems_number; std::shared_ptr<Node> elems_number;
const auto value_elem_type = value.get_element_type(); const auto value_elem_type = value.get_element_type();
const auto reduction_axes_const = ngraph::opset1::Constant::create( const auto reduction_axes_const =
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()); ngraph::opset1::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector());
const auto value_elems_sum = const auto value_elems_sum = std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes_const, keep_dims);
std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes_const, keep_dims); if (value.get_partial_shape().is_static()) {
if (value.get_partial_shape().is_static())
{
const auto elems_number_value = get_num_elements(value.get_shape(), reduction_axes); const auto elems_number_value = get_num_elements(value.get_shape(), reduction_axes);
elems_number = ngraph::opset1::Constant::create( elems_number = ngraph::opset1::Constant::create(value_elem_type, Shape{}, {elems_number_value});
value_elem_type, Shape{}, {elems_number_value}); } else {
}
else
{
elems_number = get_num_elements(value, reduction_axes_const); elems_number = get_num_elements(value, reduction_axes_const);
elems_number = elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
} }
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number) return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
@ -72,12 +58,10 @@ namespace ngraph
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value, std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims) bool keep_dims) {
{
std::shared_ptr<Node> elems_number; std::shared_ptr<Node> elems_number;
const auto value_elem_type = value.get_element_type(); const auto value_elem_type = value.get_element_type();
const auto value_elems_sum = const auto value_elems_sum = std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes, keep_dims);
std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes, keep_dims);
elems_number = get_num_elements(value, reduction_axes); elems_number = get_num_elements(value, reduction_axes);
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type); elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
@ -87,8 +71,7 @@ namespace ngraph
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value, std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
const AxisSet& reduction_axes, const AxisSet& reduction_axes,
const bool bessel_correction) const bool bessel_correction) {
{
const bool keep_dims = true; const bool keep_dims = true;
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims); std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
@ -96,21 +79,17 @@ namespace ngraph
diff = std::make_shared<ngraph::opset1::ReduceSum>( diff = std::make_shared<ngraph::opset1::ReduceSum>(
std::make_shared<ngraph::opset1::Multiply>(diff, diff), std::make_shared<ngraph::opset1::Multiply>(diff, diff),
ngraph::opset1::Constant::create( ngraph::opset1::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()),
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()),
false); false);
const auto& et = value.get_element_type(); const auto& et = value.get_element_type();
const auto N = get_num_elements(value.get_shape(), reduction_axes); const auto N = get_num_elements(value.get_shape(), reduction_axes);
std::shared_ptr<Node> result; std::shared_ptr<Node> result;
if (bessel_correction) if (bessel_correction) {
{
const auto N1const = ngraph::opset1::Constant::create(et, Shape{}, {N - 1}); const auto N1const = ngraph::opset1::Constant::create(et, Shape{}, {N - 1});
result = std::make_shared<ngraph::opset1::Divide>(diff, N1const); result = std::make_shared<ngraph::opset1::Divide>(diff, N1const);
} } else {
else
{
const auto Nconst = ngraph::opset1::Constant::create(et, Shape{}, {N}); const auto Nconst = ngraph::opset1::Constant::create(et, Shape{}, {N});
result = std::make_shared<ngraph::opset1::Divide>(diff, Nconst); result = std::make_shared<ngraph::opset1::Divide>(diff, Nconst);
} }
@ -120,22 +99,21 @@ namespace ngraph
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value, std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims, bool keep_dims,
bool bessel_correction) bool bessel_correction) {
{
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims); std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu); Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
diff = std::make_shared<ngraph::opset1::ReduceSum>( diff = std::make_shared<ngraph::opset1::ReduceSum>(std::make_shared<ngraph::opset1::Multiply>(diff, diff),
std::make_shared<ngraph::opset1::Multiply>(diff, diff), reduction_axes, keep_dims); reduction_axes,
keep_dims);
const auto& et = value.get_element_type(); const auto& et = value.get_element_type();
auto N = get_num_elements(value, reduction_axes); auto N = get_num_elements(value, reduction_axes);
N = std::make_shared<ngraph::opset1::Convert>(N, et); N = std::make_shared<ngraph::opset1::Convert>(N, et);
std::shared_ptr<Node> result; std::shared_ptr<Node> result;
if (bessel_correction) if (bessel_correction) {
{
const auto one = std::make_shared<ngraph::opset1::Constant>(et, Shape{}, 1); const auto one = std::make_shared<ngraph::opset1::Constant>(et, Shape{}, 1);
N = std::make_shared<ngraph::opset1::Subtract>(N, one); N = std::make_shared<ngraph::opset1::Subtract>(N, one);
} }

View File

@ -2,13 +2,14 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "ngraph/builder/reshape.hpp"
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/reduce_prod.hpp" #include "ngraph/op/reduce_prod.hpp"
@ -24,71 +25,54 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape) shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape) {
{ if (value.get_partial_shape().same_scheme(shape)) {
if (value.get_partial_shape().same_scheme(shape))
{
return value.get_node_shared_ptr(); return value.get_node_shared_ptr();
} } else if (is_scalar(shape)) {
else if (is_scalar(shape))
{
auto value_rank = value.get_shape().size(); auto value_rank = value.get_shape().size();
AxisVector axes_vector(value_rank); AxisVector axes_vector(value_rank);
std::iota(axes_vector.begin(), axes_vector.end(), 0); std::iota(axes_vector.begin(), axes_vector.end(), 0);
auto axes = op::Constant::create(element::i64, Shape{value_rank}, axes_vector); auto axes = op::Constant::create(element::i64, Shape{value_rank}, axes_vector);
return std::make_shared<op::Squeeze>(value, axes); return std::make_shared<op::Squeeze>(value, axes);
} } else {
else auto out_pattern =
{ op::Constant::create(element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
auto out_pattern = op::Constant::create(
element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
return make_shared<ngraph::opset1::Reshape>(value, out_pattern, false) return make_shared<ngraph::opset1::Reshape>(value, out_pattern, false)
->add_provenance_group_members_above({value}); ->add_provenance_group_members_above({value});
} }
} }
shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order) shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order) {
{ const auto axes_order_const = op::Constant::create(element::i64,
const auto axes_order_const =
op::Constant::create(element::i64,
Shape{axes_order.size()}, Shape{axes_order.size()},
vector<int64_t>(axes_order.begin(), axes_order.end())); vector<int64_t>(axes_order.begin(), axes_order.end()));
return make_shared<ngraph::opset1::Transpose>(value, axes_order_const) return make_shared<ngraph::opset1::Transpose>(value, axes_order_const)->add_provenance_group_members_above({value});
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value) shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value) {
{
// This part is left to preserve backward compatibility and ensure passing ONNX tests. // This part is left to preserve backward compatibility and ensure passing ONNX tests.
if (value.get_partial_shape().is_static()) if (value.get_partial_shape().is_static()) {
{
vector<size_t> axes_order(value.get_shape().size()); vector<size_t> axes_order(value.get_shape().size());
iota(begin(axes_order), end(axes_order), 0); iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order)); reverse(begin(axes_order), end(axes_order));
return builder::opset1::reorder_axes(value, axes_order); return builder::opset1::reorder_axes(value, axes_order);
} }
const auto input_rank = const auto input_rank = std::make_shared<ngraph::opset1::ShapeOf>(std::make_shared<ngraph::opset1::ShapeOf>(value));
std::make_shared<ngraph::opset1::ShapeOf>(std::make_shared<ngraph::opset1::ShapeOf>(value));
const auto neg_one = ngraph::opset1::Constant::create(element::i64, Shape{}, {-1}); const auto neg_one = ngraph::opset1::Constant::create(element::i64, Shape{}, {-1});
const auto start_node = std::make_shared<ngraph::opset1::Add>(input_rank, neg_one); const auto start_node = std::make_shared<ngraph::opset1::Add>(input_rank, neg_one);
const auto reverse_axes_order = const auto reverse_axes_order = std::make_shared<ngraph::opset1::Range>(reshape(start_node, Shape{}), // start
std::make_shared<ngraph::opset1::Range>(reshape(start_node, Shape{}), // start
neg_one, // stop (exclusive) neg_one, // stop (exclusive)
neg_one); // step neg_one); // step
return std::make_shared<ngraph::opset1::Transpose>(value, reverse_axes_order) return std::make_shared<ngraph::opset1::Transpose>(value, reverse_axes_order)
->add_provenance_group_members_above({value}); ->add_provenance_group_members_above({value});
} }
namespace ngraph namespace ngraph {
{ namespace builder {
namespace builder namespace opset1 {
{ namespace {
namespace opset1
{
namespace
{
/// ///
/// \brief Return the node representing normalized axis with respect to /// \brief Return the node representing normalized axis with respect to
/// provided rank. /// provided rank.
@ -98,14 +82,10 @@ namespace ngraph
/// ///
/// \return The new Constant node representing normalized axis value. /// \return The new Constant node representing normalized axis value.
/// ///
std::shared_ptr<Node> std::shared_ptr<Node> get_normalized_axis_node(const std::shared_ptr<Node> node_rank, int64_t axis) {
get_normalized_axis_node(const std::shared_ptr<Node> node_rank, int64_t axis) auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{1}, {axis});
{
auto axis_node =
ngraph::opset1::Constant::create(element::i64, Shape{1}, {axis});
// shortcut for alredy positive value // shortcut for alredy positive value
if (axis >= 0) if (axis >= 0) {
{
return axis_node; return axis_node;
} }
@ -118,47 +98,40 @@ namespace ngraph
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph
shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis) shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis) {
{
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of // First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of
// input tensor. The last dimension is the product of the rest of input tensor dimensions: // input tensor. The last dimension is the product of the rest of input tensor dimensions:
// [d_{axis}, ..., d_n] // [d_{axis}, ..., d_n]
shared_ptr<Node> output_shape; shared_ptr<Node> output_shape;
if (axis == 0) if (axis == 0) {
{
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {1, -1}); output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {1, -1});
} } else if (axis == 1) {
else if (axis == 1)
{
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {0, -1}); output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {0, -1});
} } else {
else
{
const auto value_shape = make_shared<ngraph::opset1::ShapeOf>(value); const auto value_shape = make_shared<ngraph::opset1::ShapeOf>(value);
const auto value_rank = make_shared<ngraph::opset1::ShapeOf>(value_shape); const auto value_rank = make_shared<ngraph::opset1::ShapeOf>(value_shape);
const auto axis_node = get_normalized_axis_node(value_rank, axis); const auto axis_node = get_normalized_axis_node(value_rank, axis);
const auto first_part_dims = make_shared<ngraph::opset1::StridedSlice>( const auto first_part_dims =
value_shape, make_shared<ngraph::opset1::StridedSlice>(value_shape,
ngraph::opset1::Constant::create(element::i64, {1}, {0}), ngraph::opset1::Constant::create(element::i64, {1}, {0}),
axis_node, axis_node,
vector<int64_t>{}, vector<int64_t>{},
vector<int64_t>{}); vector<int64_t>{});
const auto first_part_dims_length = make_shared<ngraph::opset1::ReduceProd>( const auto first_part_dims_length =
first_part_dims, ngraph::opset1::Constant::create(element::i64, {}, {0}), true); make_shared<ngraph::opset1::ReduceProd>(first_part_dims,
ngraph::opset1::Constant::create(element::i64, {}, {0}),
true);
const auto remaining_part_length = const auto remaining_part_length = ngraph::opset1::Constant::create(element::i64, {1}, {-1});
ngraph::opset1::Constant::create(element::i64, {1}, {-1});
output_shape = make_shared<ngraph::opset1::Concat>( output_shape =
OutputVector{first_part_dims_length, remaining_part_length}, 0); make_shared<ngraph::opset1::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
} }
return make_shared<ngraph::opset1::Reshape>(value, output_shape, true) return make_shared<ngraph::opset1::Reshape>(value, output_shape, true)->add_provenance_group_members_above({value});
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis) shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis) {
{
Shape output_shape(value.get_shape()); Shape output_shape(value.get_shape());
// Add empty axis at specified position. // Add empty axis at specified position.
auto empty_axis_it = begin(output_shape); auto empty_axis_it = begin(output_shape);
@ -167,40 +140,30 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
return builder::opset1::reshape(value, output_shape); return builder::opset1::reshape(value, output_shape);
} }
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes) shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes) {
{ if (axes.empty()) {
if (axes.empty())
{
return value.get_node_shared_ptr(); return value.get_node_shared_ptr();
} }
Shape in_shape{value.get_shape()}; Shape in_shape{value.get_shape()};
for (size_t idx = 0; idx < axes.size(); ++idx) for (size_t idx = 0; idx < axes.size(); ++idx) {
{
in_shape.at(axes.at(idx)) = 0; in_shape.at(axes.at(idx)) = 0;
} }
Shape output_shape; Shape output_shape;
for (auto axis : in_shape) for (auto axis : in_shape) {
{ if (axis != 0) {
if (axis != 0)
{
output_shape.push_back(axis); output_shape.push_back(axis);
} }
} }
return builder::opset1::reshape(value, output_shape); return builder::opset1::reshape(value, output_shape);
} }
shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value, shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value, const size_t start_axis, const size_t end_axis) {
const size_t start_axis, if (start_axis == end_axis) {
const size_t end_axis)
{
if (start_axis == end_axis)
{
return value.get_node_shared_ptr(); return value.get_node_shared_ptr();
} }
if (value.get_partial_shape().is_static()) if (value.get_partial_shape().is_static()) {
{
auto shape = value.get_shape(); auto shape = value.get_shape();
// Multiply all alements of shape from start_axis to end_axis inclusive // Multiply all alements of shape from start_axis to end_axis inclusive
size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis), size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
@ -220,22 +183,20 @@ shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value,
// Split lengths used in VariadicSplit // Split lengths used in VariadicSplit
const auto start_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {start_axis}); const auto start_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {start_axis});
const auto end_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {end_axis + 1}); const auto end_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {end_axis + 1});
const auto collapsed_axis = const auto collapsed_axis = make_shared<ngraph::opset1::Subtract>(end_axis_node, start_axis_node);
make_shared<ngraph::opset1::Subtract>(end_axis_node, start_axis_node);
const auto post_axis = make_shared<ngraph::opset1::Subtract>(rank, end_axis_node); const auto post_axis = make_shared<ngraph::opset1::Subtract>(rank, end_axis_node);
const auto split_lengths = make_shared<ngraph::opset1::Concat>( const auto split_lengths =
OutputVector{start_axis_node, collapsed_axis, post_axis}, 0); make_shared<ngraph::opset1::Concat>(OutputVector{start_axis_node, collapsed_axis, post_axis}, 0);
const auto split_axis = ngraph::opset1::Constant::create(element::i64, {}, {0}); const auto split_axis = ngraph::opset1::Constant::create(element::i64, {}, {0});
const auto split_node = const auto split_node = make_shared<ngraph::opset1::VariadicSplit>(shape, split_axis, split_lengths);
make_shared<ngraph::opset1::VariadicSplit>(shape, split_axis, split_lengths);
const auto reduced_axis = ngraph::opset1::Constant::create(element::i64, {1}, {0}); const auto reduced_axis = ngraph::opset1::Constant::create(element::i64, {1}, {0});
const auto collapsed_axis_size = const auto collapsed_axis_size = make_shared<ngraph::opset1::ReduceProd>(split_node->output(1), reduced_axis, true);
make_shared<ngraph::opset1::ReduceProd>(split_node->output(1), reduced_axis, true);
const auto collapsed_shape = make_shared<ngraph::opset1::Concat>( const auto collapsed_shape = make_shared<ngraph::opset1::Concat>(
OutputVector{split_node->output(0), collapsed_axis_size, split_node->output(2)}, 0); OutputVector{split_node->output(0), collapsed_axis_size, split_node->output(2)},
0);
return make_shared<ngraph::opset1::Reshape>(value, collapsed_shape, false); return make_shared<ngraph::opset1::Reshape>(value, collapsed_shape, false);
} }

View File

@ -3,25 +3,21 @@
// //
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset1.hpp"
using namespace ngraph; using namespace ngraph;
OutputVector builder::opset1::split(const Output<Node>& value, OutputVector builder::opset1::split(const Output<Node>& value, const std::vector<size_t>& split_lengths, int64_t axis) {
const std::vector<size_t>& split_lengths,
int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis}); const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
const auto split_lengths_node = const auto split_lengths_node =
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths); ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split = const auto variadic_split = std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
return variadic_split->outputs(); return variadic_split->outputs();
} }
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis) OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis) {
{
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis}); const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits); const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);

View File

@ -12,8 +12,7 @@
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
/// ///
namespace ngraph namespace ngraph {
{
class AttributeVisitor; class AttributeVisitor;
/// \brief Provides access to an attribute of type AT as a value accessor type VAT /// \brief Provides access to an attribute of type AT as a value accessor type VAT
@ -26,8 +25,7 @@ namespace ngraph
/// All ValueAccessors must be derived from ValueAccessor<void> so that an AttributeVisitor /// All ValueAccessors must be derived from ValueAccessor<void> so that an AttributeVisitor
/// only needs to implement a subset of the on_adapter methods. /// only needs to implement a subset of the on_adapter methods.
template <> template <>
class NGRAPH_API ValueAccessor<void> class NGRAPH_API ValueAccessor<void> {
{
public: public:
/// \brief type info enables identification of the value accessor, as well as is_type and /// \brief type info enables identification of the value accessor, as well as is_type and
/// as_type. /// as_type.
@ -45,8 +43,7 @@ namespace ngraph
/// changed. /// changed.
/// \tparam VAT The adapter value type; may be wider than the value being accessed. /// \tparam VAT The adapter value type; may be wider than the value being accessed.
template <typename VAT> template <typename VAT>
class ValueAccessor : public ValueAccessor<void> class ValueAccessor : public ValueAccessor<void> {
{
public: public:
/// Returns the value /// Returns the value
virtual const VAT& get() = 0; virtual const VAT& get() = 0;
@ -55,50 +52,41 @@ namespace ngraph
}; };
template <> template <>
class ValueAccessor<void*> : public ValueAccessor<void> class ValueAccessor<void*> : public ValueAccessor<void> {
{
public: public:
virtual void* get_ptr() = 0; virtual void* get_ptr() = 0;
virtual size_t size() = 0; virtual size_t size() = 0;
}; };
template <typename AT> template <typename AT>
class DirectValueAccessor : public ValueAccessor<AT> class DirectValueAccessor : public ValueAccessor<AT> {
{
public: public:
DirectValueAccessor(AT& ref) DirectValueAccessor(AT& ref) : m_ref(ref) {}
: m_ref(ref) const AT& get() override {
{ return m_ref;
}
void set(const AT& value) override {
m_ref = value;
} }
const AT& get() override { return m_ref; }
void set(const AT& value) override { m_ref = value; }
protected: protected:
AT& m_ref; AT& m_ref;
}; };
template <typename AT, typename VAT> template <typename AT, typename VAT>
class IndirectScalarValueAccessor : public ValueAccessor<VAT> class IndirectScalarValueAccessor : public ValueAccessor<VAT> {
{
public: public:
IndirectScalarValueAccessor(AT& ref) IndirectScalarValueAccessor(AT& ref) : m_ref(ref), m_buffer() {}
: m_ref(ref)
, m_buffer()
{
}
const VAT& get() override const VAT& get() override {
{ if (!m_buffer_valid) {
if (!m_buffer_valid)
{
m_buffer = static_cast<VAT>(m_ref); m_buffer = static_cast<VAT>(m_ref);
m_buffer_valid = true; m_buffer_valid = true;
} }
return m_buffer; return m_buffer;
} }
void set(const VAT& value) override void set(const VAT& value) override {
{
m_ref = static_cast<AT>(value); m_ref = static_cast<AT>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
@ -110,43 +98,35 @@ namespace ngraph
}; };
template <typename A, typename B> template <typename A, typename B>
A copy_from(B& b) A copy_from(B& b) {
{
A result(b.size()); A result(b.size());
for (size_t i = 0; i < b.size(); ++i) for (size_t i = 0; i < b.size(); ++i) {
{ result[i] = static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
result[i] =
static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
} }
return result; return result;
} }
template <typename AT, typename VAT> template <typename AT, typename VAT>
class IndirectVectorValueAccessor : public ValueAccessor<VAT> class IndirectVectorValueAccessor : public ValueAccessor<VAT> {
{
public: public:
IndirectVectorValueAccessor(AT& ref) IndirectVectorValueAccessor(AT& ref) : m_ref(ref) {}
: m_ref(ref)
{
}
const VAT& get() override const VAT& get() override {
{ if (!m_buffer_valid) {
if (!m_buffer_valid)
{
m_buffer = copy_from<typename std::remove_cv<VAT>::type>(m_ref); m_buffer = copy_from<typename std::remove_cv<VAT>::type>(m_ref);
m_buffer_valid = true; m_buffer_valid = true;
} }
return m_buffer; return m_buffer;
} }
void set(const VAT& value) override void set(const VAT& value) override {
{
m_ref = copy_from<AT>(value); m_ref = copy_from<AT>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
operator AT&() { return m_ref; } operator AT&() {
return m_ref;
}
protected: protected:
AT& m_ref; AT& m_ref;
@ -157,236 +137,202 @@ namespace ngraph
/// \brief An AttributeAdapter "captures" an attribute as an AT& and makes it available as a /// \brief An AttributeAdapter "captures" an attribute as an AT& and makes it available as a
/// ValueAccessor<VAT>. /// ValueAccessor<VAT>.
template <typename AT> template <typename AT>
class AttributeAdapter class AttributeAdapter {};
{
};
/// \brief Access an enum via a string /// \brief Access an enum via a string
/// \tparam AT The attribute type enum class /// \tparam AT The attribute type enum class
template <typename AT> template <typename AT>
class EnumAttributeAdapterBase : public ValueAccessor<std::string> class EnumAttributeAdapterBase : public ValueAccessor<std::string> {
{
public: public:
EnumAttributeAdapterBase(AT& value) EnumAttributeAdapterBase(AT& value) : m_ref(value) {}
: m_ref(value)
{
}
const std::string& get() override { return as_string(m_ref); } const std::string& get() override {
void set(const std::string& value) override { m_ref = as_enum<AT>(value); } return as_string(m_ref);
operator AT&() { return m_ref; } }
void set(const std::string& value) override {
m_ref = as_enum<AT>(value);
}
operator AT&() {
return m_ref;
}
protected: protected:
AT& m_ref; AT& m_ref;
}; };
/// Adapters will see visitor /// Adapters will see visitor
class VisitorAdapter : public ValueAccessor<void> class VisitorAdapter : public ValueAccessor<void> {
{
public: public:
virtual bool visit_attributes(AttributeVisitor& visitor) = 0; virtual bool visit_attributes(AttributeVisitor& visitor) = 0;
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<float> : public IndirectScalarValueAccessor<float, double> class NGRAPH_API AttributeAdapter<float> : public IndirectScalarValueAccessor<float, double> {
{
public: public:
AttributeAdapter(float& value) AttributeAdapter(float& value) : IndirectScalarValueAccessor<float, double>(value) {}
: IndirectScalarValueAccessor<float, double>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<float>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<float>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a double as a double /// \brief Access a double as a double
template <> template <>
class NGRAPH_API AttributeAdapter<double> : public DirectValueAccessor<double> class NGRAPH_API AttributeAdapter<double> : public DirectValueAccessor<double> {
{
public: public:
AttributeAdapter(double& value) AttributeAdapter(double& value) : DirectValueAccessor<double>(value) {}
: DirectValueAccessor<double>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<double>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<double>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a string as a string /// \brief Access a string as a string
template <> template <>
class NGRAPH_API AttributeAdapter<std::string> : public DirectValueAccessor<std::string> class NGRAPH_API AttributeAdapter<std::string> : public DirectValueAccessor<std::string> {
{
public: public:
AttributeAdapter(std::string& value) AttributeAdapter(std::string& value) : DirectValueAccessor<std::string>(value) {}
: DirectValueAccessor<std::string>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<string>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<string>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a bool as a bool /// \brief Access a bool as a bool
template <> template <>
class NGRAPH_API AttributeAdapter<bool> : public DirectValueAccessor<bool> class NGRAPH_API AttributeAdapter<bool> : public DirectValueAccessor<bool> {
{
public: public:
AttributeAdapter(bool& value) AttributeAdapter(bool& value) : DirectValueAccessor<bool>(value) {}
: DirectValueAccessor<bool>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<bool>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<bool>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access an int8_t and an int64_t /// \brief Access an int8_t and an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<int8_t> : public IndirectScalarValueAccessor<int8_t, int64_t> class NGRAPH_API AttributeAdapter<int8_t> : public IndirectScalarValueAccessor<int8_t, int64_t> {
{
public: public:
AttributeAdapter(int8_t& value) AttributeAdapter(int8_t& value) : IndirectScalarValueAccessor<int8_t, int64_t>(value) {}
: IndirectScalarValueAccessor<int8_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int8_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int8_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access an int16_t as an int64_t /// \brief Access an int16_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<int16_t> class NGRAPH_API AttributeAdapter<int16_t> : public IndirectScalarValueAccessor<int16_t, int64_t> {
: public IndirectScalarValueAccessor<int16_t, int64_t>
{
public: public:
AttributeAdapter(int16_t& value) AttributeAdapter(int16_t& value) : IndirectScalarValueAccessor<int16_t, int64_t>(value) {}
: IndirectScalarValueAccessor<int16_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int16_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int16_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access an int32_t as an int64_t /// \brief Access an int32_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<int32_t> class NGRAPH_API AttributeAdapter<int32_t> : public IndirectScalarValueAccessor<int32_t, int64_t> {
: public IndirectScalarValueAccessor<int32_t, int64_t>
{
public: public:
AttributeAdapter(int32_t& value) AttributeAdapter(int32_t& value) : IndirectScalarValueAccessor<int32_t, int64_t>(value) {}
: IndirectScalarValueAccessor<int32_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int32_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int32_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access an int64_t as an int64_t /// \brief Access an int64_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<int64_t> : public DirectValueAccessor<int64_t> class NGRAPH_API AttributeAdapter<int64_t> : public DirectValueAccessor<int64_t> {
{
public: public:
AttributeAdapter(int64_t& value) AttributeAdapter(int64_t& value) : DirectValueAccessor<int64_t>(value) {}
: DirectValueAccessor<int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int64_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int64_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a uint8_t as an int64_t /// \brief Access a uint8_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<uint8_t> class NGRAPH_API AttributeAdapter<uint8_t> : public IndirectScalarValueAccessor<uint8_t, int64_t> {
: public IndirectScalarValueAccessor<uint8_t, int64_t>
{
public: public:
AttributeAdapter(uint8_t& value) AttributeAdapter(uint8_t& value) : IndirectScalarValueAccessor<uint8_t, int64_t>(value) {}
: IndirectScalarValueAccessor<uint8_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint8_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint8_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a uint16_t as an int64_t /// \brief Access a uint16_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<uint16_t> class NGRAPH_API AttributeAdapter<uint16_t> : public IndirectScalarValueAccessor<uint16_t, int64_t> {
: public IndirectScalarValueAccessor<uint16_t, int64_t>
{
public: public:
AttributeAdapter(uint16_t& value) AttributeAdapter(uint16_t& value) : IndirectScalarValueAccessor<uint16_t, int64_t>(value) {}
: IndirectScalarValueAccessor<uint16_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint16_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint16_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a uint32_t as an int64_t /// \brief Access a uint32_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<uint32_t> class NGRAPH_API AttributeAdapter<uint32_t> : public IndirectScalarValueAccessor<uint32_t, int64_t> {
: public IndirectScalarValueAccessor<uint32_t, int64_t>
{
public: public:
AttributeAdapter(uint32_t& value) AttributeAdapter(uint32_t& value) : IndirectScalarValueAccessor<uint32_t, int64_t>(value) {}
: IndirectScalarValueAccessor<uint32_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint32_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint32_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a uint64_t as an int64_t /// \brief Access a uint64_t as an int64_t
template <> template <>
class NGRAPH_API AttributeAdapter<uint64_t> class NGRAPH_API AttributeAdapter<uint64_t> : public IndirectScalarValueAccessor<uint64_t, int64_t> {
: public IndirectScalarValueAccessor<uint64_t, int64_t>
{
public: public:
AttributeAdapter(uint64_t& value) AttributeAdapter(uint64_t& value) : IndirectScalarValueAccessor<uint64_t, int64_t>(value) {}
: IndirectScalarValueAccessor<uint64_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint64_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint64_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
#ifdef __APPLE__ #ifdef __APPLE__
// size_t is one of the uint types on _WIN32 // size_t is one of the uint types on _WIN32
template <> template <>
class NGRAPH_API AttributeAdapter<size_t> : public IndirectScalarValueAccessor<size_t, int64_t> class NGRAPH_API AttributeAdapter<size_t> : public IndirectScalarValueAccessor<size_t, int64_t> {
{
public: public:
AttributeAdapter(size_t& value) AttributeAdapter(size_t& value) : IndirectScalarValueAccessor<size_t, int64_t>(value) {}
: IndirectScalarValueAccessor<size_t, int64_t>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<size_t>> class NGRAPH_API AttributeAdapter<std::vector<size_t>>
: public IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>> : public IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>> {
{
public: public:
AttributeAdapter(std::vector<size_t>& value) AttributeAdapter(std::vector<size_t>& value)
: IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>(value) : IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>(value) {}
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<size_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<size_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
#endif #endif
@ -395,166 +341,133 @@ namespace ngraph
/// \brief Access a vector<int8_t> /// \brief Access a vector<int8_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<int8_t>> class NGRAPH_API AttributeAdapter<std::vector<int8_t>> : public DirectValueAccessor<std::vector<int8_t>> {
: public DirectValueAccessor<std::vector<int8_t>>
{
public: public:
AttributeAdapter(std::vector<int8_t>& value) AttributeAdapter(std::vector<int8_t>& value) : DirectValueAccessor<std::vector<int8_t>>(value) {}
: DirectValueAccessor<std::vector<int8_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int8_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int8_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<int16_t> /// \brief Access a vector<int16_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<int16_t>> class NGRAPH_API AttributeAdapter<std::vector<int16_t>> : public DirectValueAccessor<std::vector<int16_t>> {
: public DirectValueAccessor<std::vector<int16_t>>
{
public: public:
AttributeAdapter(std::vector<int16_t>& value) AttributeAdapter(std::vector<int16_t>& value) : DirectValueAccessor<std::vector<int16_t>>(value) {}
: DirectValueAccessor<std::vector<int16_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int16_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int16_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<int32_t> /// \brief Access a vector<int32_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<int32_t>> class NGRAPH_API AttributeAdapter<std::vector<int32_t>> : public DirectValueAccessor<std::vector<int32_t>> {
: public DirectValueAccessor<std::vector<int32_t>>
{
public: public:
AttributeAdapter(std::vector<int32_t>& value) AttributeAdapter(std::vector<int32_t>& value) : DirectValueAccessor<std::vector<int32_t>>(value) {}
: DirectValueAccessor<std::vector<int32_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int32_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int32_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<int64_t> /// \brief Access a vector<int64_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<int64_t>> class NGRAPH_API AttributeAdapter<std::vector<int64_t>> : public DirectValueAccessor<std::vector<int64_t>> {
: public DirectValueAccessor<std::vector<int64_t>>
{
public: public:
AttributeAdapter(std::vector<int64_t>& value) AttributeAdapter(std::vector<int64_t>& value) : DirectValueAccessor<std::vector<int64_t>>(value) {}
: DirectValueAccessor<std::vector<int64_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int64_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int64_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<uint8_t> /// \brief Access a vector<uint8_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<uint8_t>> class NGRAPH_API AttributeAdapter<std::vector<uint8_t>> : public DirectValueAccessor<std::vector<uint8_t>> {
: public DirectValueAccessor<std::vector<uint8_t>>
{
public: public:
AttributeAdapter(std::vector<uint8_t>& value) AttributeAdapter(std::vector<uint8_t>& value) : DirectValueAccessor<std::vector<uint8_t>>(value) {}
: DirectValueAccessor<std::vector<uint8_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint8_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint8_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<uint16_t> /// \brief Access a vector<uint16_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<uint16_t>> class NGRAPH_API AttributeAdapter<std::vector<uint16_t>> : public DirectValueAccessor<std::vector<uint16_t>> {
: public DirectValueAccessor<std::vector<uint16_t>>
{
public: public:
AttributeAdapter(std::vector<uint16_t>& value) AttributeAdapter(std::vector<uint16_t>& value) : DirectValueAccessor<std::vector<uint16_t>>(value) {}
: DirectValueAccessor<std::vector<uint16_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint16_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint16_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<uint32_t> /// \brief Access a vector<uint32_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<uint32_t>> class NGRAPH_API AttributeAdapter<std::vector<uint32_t>> : public DirectValueAccessor<std::vector<uint32_t>> {
: public DirectValueAccessor<std::vector<uint32_t>>
{
public: public:
AttributeAdapter(std::vector<uint32_t>& value) AttributeAdapter(std::vector<uint32_t>& value) : DirectValueAccessor<std::vector<uint32_t>>(value) {}
: DirectValueAccessor<std::vector<uint32_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint32_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint32_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<uint64_t> /// \brief Access a vector<uint64_t>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<uint64_t>> class NGRAPH_API AttributeAdapter<std::vector<uint64_t>> : public DirectValueAccessor<std::vector<uint64_t>> {
: public DirectValueAccessor<std::vector<uint64_t>>
{
public: public:
AttributeAdapter(std::vector<uint64_t>& value) AttributeAdapter(std::vector<uint64_t>& value) : DirectValueAccessor<std::vector<uint64_t>>(value) {}
: DirectValueAccessor<std::vector<uint64_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint64_t>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint64_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<float> /// \brief Access a vector<float>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<float>> class NGRAPH_API AttributeAdapter<std::vector<float>> : public DirectValueAccessor<std::vector<float>> {
: public DirectValueAccessor<std::vector<float>>
{
public: public:
AttributeAdapter(std::vector<float>& value) AttributeAdapter(std::vector<float>& value) : DirectValueAccessor<std::vector<float>>(value) {}
: DirectValueAccessor<std::vector<float>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<double> /// \brief Access a vector<double>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<double>> class NGRAPH_API AttributeAdapter<std::vector<double>> : public DirectValueAccessor<std::vector<double>> {
: public DirectValueAccessor<std::vector<double>>
{
public: public:
AttributeAdapter(std::vector<double>& value) AttributeAdapter(std::vector<double>& value) : DirectValueAccessor<std::vector<double>>(value) {}
: DirectValueAccessor<std::vector<double>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<double>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<double>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
/// \brief Access a vector<string> /// \brief Access a vector<string>
template <> template <>
class NGRAPH_API AttributeAdapter<std::vector<std::string>> class NGRAPH_API AttributeAdapter<std::vector<std::string>> : public DirectValueAccessor<std::vector<std::string>> {
: public DirectValueAccessor<std::vector<std::string>>
{
public: public:
AttributeAdapter(std::vector<std::string>& value) AttributeAdapter(std::vector<std::string>& value) : DirectValueAccessor<std::vector<std::string>>(value) {}
: DirectValueAccessor<std::vector<std::string>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
} // namespace ngraph } // namespace ngraph

View File

@ -12,8 +12,7 @@
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
namespace ngraph namespace ngraph {
{
template <typename T> template <typename T>
class ValueAccessor; class ValueAccessor;
class VisitorAdapter; class VisitorAdapter;
@ -55,8 +54,7 @@ namespace ngraph
/// registered with the visitor using register_node, which needs a shared pointer to a node and /// registered with the visitor using register_node, which needs a shared pointer to a node and
/// a string ID. The ID string will be used to serialize the node or find the node during /// a string ID. The ID string will be used to serialize the node or find the node during
/// deserialization. /// deserialization.
class NGRAPH_API AttributeVisitor class NGRAPH_API AttributeVisitor {
{
public: public:
virtual ~AttributeVisitor() {} virtual ~AttributeVisitor() {}
// Must implement these methods // Must implement these methods
@ -80,49 +78,38 @@ namespace ngraph
virtual void on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<float>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<float>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter);
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int8_t>>& adapter);
ValueAccessor<std::vector<int8_t>>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int16_t>>& adapter);
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int32_t>>& adapter);
ValueAccessor<std::vector<int16_t>>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int64_t>>& adapter);
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint8_t>>& adapter);
ValueAccessor<std::vector<int32_t>>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint16_t>>& adapter);
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint32_t>>& adapter);
ValueAccessor<std::vector<int64_t>>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint64_t>>& adapter);
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter);
ValueAccessor<std::vector<uint8_t>>& adapter); virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<double>>& adapter);
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<std::string>>& adapter);
ValueAccessor<std::vector<uint16_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint32_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint64_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<float>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<double>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter);
/// \brief Hook for adapters that need visitor access /// \brief Hook for adapters that need visitor access
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter); virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor /// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
/// \param name attribute name /// \param name attribute name
/// \param adapter reference to a Function ValueAccessor<VAT> /// \param adapter reference to a Function ValueAccessor<VAT>
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name, ValueAccessor<std::shared_ptr<Function>>& adapter);
ValueAccessor<std::shared_ptr<Function>>& adapter);
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert /// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
/// to a ValueAccessor<U> for one of the on_adpater methods. /// to a ValueAccessor<U> for one of the on_adpater methods.
template <typename AT> template <typename AT>
void on_attribute(const std::string& name, AT& value) void on_attribute(const std::string& name, AT& value) {
{
AttributeAdapter<AT> adapter(value); AttributeAdapter<AT> adapter(value);
start_structure(name); start_structure(name);
on_adapter(get_name_with_context(), adapter); on_adapter(get_name_with_context(), adapter);
finish_structure(); finish_structure();
} }
/// \returns The nested context of visits /// \returns The nested context of visits
const std::vector<std::string>& get_context() const { return m_context; } const std::vector<std::string>& get_context() const {
return m_context;
}
/// \returns context prepended to names /// \returns context prepended to names
virtual std::string get_name_with_context(); virtual std::string get_name_with_context();
/// \brief Start visiting a nested structure /// \brief Start visiting a nested structure
@ -135,8 +122,7 @@ namespace ngraph
/// ///
/// No node may be used as an attribute unless it has already been registered with an ID. /// No node may be used as an attribute unless it has already been registered with an ID.
/// References to nodes are visited with a ValueAccessor of their ID. /// References to nodes are visited with a ValueAccessor of their ID.
virtual void register_node(const std::shared_ptr<Node>& node, virtual void register_node(const std::shared_ptr<Node>& node, node_id_t id = invalid_node_id);
node_id_t id = invalid_node_id);
/// Returns the node with the given id, or nullptr if there is no registered node /// Returns the node with the given id, or nullptr if there is no registered node
virtual std::shared_ptr<Node> get_registered_node(node_id_t id); virtual std::shared_ptr<Node> get_registered_node(node_id_t id);
/// Returns the id for the node, or -1 if the node is not registered /// Returns the id for the node, or -1 if the node is not registered

View File

@ -12,11 +12,9 @@
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph {
{
/// \brief A set of axes. /// \brief A set of axes.
class AxisSet : public std::set<size_t> class AxisSet : public std::set<size_t> {
{
public: public:
NGRAPH_API AxisSet(); NGRAPH_API AxisSet();
@ -36,19 +34,19 @@ namespace ngraph
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<AxisSet> : public ValueAccessor<std::vector<int64_t>> class NGRAPH_API AttributeAdapter<AxisSet> : public ValueAccessor<std::vector<int64_t>> {
{
public: public:
AttributeAdapter(AxisSet& value) AttributeAdapter(AxisSet& value) : m_ref(value) {}
: m_ref(value)
{
}
const std::vector<int64_t>& get() override; const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override; void set(const std::vector<int64_t>& value) override;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
operator AxisSet&() { return m_ref; } return type_info;
}
operator AxisSet&() {
return m_ref;
}
protected: protected:
AxisSet& m_ref; AxisSet& m_ref;

View File

@ -11,11 +11,9 @@
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph {
{
/// \brief A vector of axes. /// \brief A vector of axes.
class AxisVector : public std::vector<size_t> class AxisVector : public std::vector<size_t> {
{
public: public:
NGRAPH_API AxisVector(const std::initializer_list<size_t>& axes); NGRAPH_API AxisVector(const std::initializer_list<size_t>& axes);
@ -26,10 +24,7 @@ namespace ngraph
NGRAPH_API explicit AxisVector(size_t n); NGRAPH_API explicit AxisVector(size_t n);
template <class InputIterator> template <class InputIterator>
AxisVector(InputIterator first, InputIterator last) AxisVector(InputIterator first, InputIterator last) : std::vector<size_t>(first, last) {}
: std::vector<size_t>(first, last)
{
}
NGRAPH_API AxisVector(); NGRAPH_API AxisVector();
@ -41,17 +36,14 @@ namespace ngraph
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<AxisVector> class NGRAPH_API AttributeAdapter<AxisVector> : public IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>> {
: public IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>
{
public: public:
AttributeAdapter(AxisVector& value) AttributeAdapter(AxisVector& value) : IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>(value) {}
: IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisVector>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisVector>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
NGRAPH_API NGRAPH_API

View File

@ -10,32 +10,26 @@
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
namespace ngraph namespace ngraph {
{ static inline std::ostream& write_all_to_stream(std::ostream& str) {
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; } return str;
}
template <typename T, typename... TS> template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS&&... args) static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS&&... args) {
{
return write_all_to_stream(str << arg, args...); return write_all_to_stream(str << arg, args...);
} }
struct CheckLocInfo struct CheckLocInfo {
{
const char* file; const char* file;
int line; int line;
const char* check_string; const char* check_string;
}; };
/// Base class for check failure exceptions. /// Base class for check failure exceptions.
class NGRAPH_API CheckFailure : public ngraph_error class NGRAPH_API CheckFailure : public ngraph_error {
{
public: public:
CheckFailure(const CheckLocInfo& check_loc_info, CheckFailure(const CheckLocInfo& check_loc_info, const std::string& context_info, const std::string& explanation)
const std::string& context_info, : ngraph_error(make_what(check_loc_info, context_info, explanation)) {}
const std::string& explanation)
: ngraph_error(make_what(check_loc_info, context_info, explanation))
{
}
private: private:
static std::string make_what(const CheckLocInfo& check_loc_info, static std::string make_what(const CheckLocInfo& check_loc_info,
@ -110,22 +104,17 @@ namespace ngraph
// variable (ss___) and risk shadowing. // variable (ss___) and risk shadowing.
// //
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) \ #define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) \
do \ do { \
{ \ if (!(check)) { \
if (!(check)) \
{ \
::std::stringstream ss___; \ ::std::stringstream ss___; \
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \ ::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
throw exc_class( \ throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
} \ } \
} while (0) } while (0)
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) \ #define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) \
do \ do { \
{ \ if (!(check)) { \
if (!(check)) \
{ \
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \ throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \
} \ } \
} while (0) } while (0)
@ -143,8 +132,7 @@ namespace ngraph
/// \param ... Additional error message that should describe why that execution path is unreachable. /// \param ... Additional error message that should describe why that execution path is unreachable.
/// \throws ::ngraph::CheckFailure if the macro is executed. /// \throws ::ngraph::CheckFailure if the macro is executed.
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__) #define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__)
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) \ #define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
#define GLUE(x, y) x y #define GLUE(x, y) x y
@ -178,33 +166,7 @@ namespace ngraph
count count
#define EXPAND_ARGS(args) RETURN_ARG_COUNT args #define EXPAND_ARGS(args) RETURN_ARG_COUNT args
#define COUNT_ARGS_MAXN(...) \ #define COUNT_ARGS_MAXN(...) \
EXPAND_ARGS((__VA_ARGS__, \ EXPAND_ARGS((__VA_ARGS__, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0))
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
1, \
0))
#define OVERLOAD_MACRO2(name, count) name##count #define OVERLOAD_MACRO2(name, count) name##count
#define OVERLOAD_MACRO1(name, count) OVERLOAD_MACRO2(name, count) #define OVERLOAD_MACRO1(name, count) OVERLOAD_MACRO2(name, count)

View File

@ -11,11 +11,9 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph {
{
/// \brief Coordinates for a tensor element /// \brief Coordinates for a tensor element
class Coordinate : public std::vector<size_t> class Coordinate : public std::vector<size_t> {
{
public: public:
NGRAPH_API Coordinate(); NGRAPH_API Coordinate();
NGRAPH_API Coordinate(const std::initializer_list<size_t>& axes); NGRAPH_API Coordinate(const std::initializer_list<size_t>& axes);
@ -31,10 +29,7 @@ namespace ngraph
NGRAPH_API ~Coordinate(); NGRAPH_API ~Coordinate();
template <class InputIterator> template <class InputIterator>
Coordinate(InputIterator first, InputIterator last) Coordinate(InputIterator first, InputIterator last) : std::vector<size_t>(first, last) {}
: std::vector<size_t>(first, last)
{
}
NGRAPH_API Coordinate& operator=(const Coordinate& v); NGRAPH_API Coordinate& operator=(const Coordinate& v);
@ -42,17 +37,14 @@ namespace ngraph
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<Coordinate> class NGRAPH_API AttributeAdapter<Coordinate> : public IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>> {
: public IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>
{
public: public:
AttributeAdapter(Coordinate& value) AttributeAdapter(Coordinate& value) : IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>(value) {}
: IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
NGRAPH_API NGRAPH_API

View File

@ -11,11 +11,9 @@
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph {
{
/// \brief A difference (signed) of tensor element coordinates. /// \brief A difference (signed) of tensor element coordinates.
class CoordinateDiff : public std::vector<std::ptrdiff_t> class CoordinateDiff : public std::vector<std::ptrdiff_t> {
{
public: public:
NGRAPH_API CoordinateDiff(const std::initializer_list<std::ptrdiff_t>& diffs); NGRAPH_API CoordinateDiff(const std::initializer_list<std::ptrdiff_t>& diffs);
@ -26,10 +24,7 @@ namespace ngraph
NGRAPH_API explicit CoordinateDiff(size_t n, std::ptrdiff_t initial_value = 0); NGRAPH_API explicit CoordinateDiff(size_t n, std::ptrdiff_t initial_value = 0);
template <class InputIterator> template <class InputIterator>
CoordinateDiff(InputIterator first, InputIterator last) CoordinateDiff(InputIterator first, InputIterator last) : std::vector<std::ptrdiff_t>(first, last) {}
: std::vector<std::ptrdiff_t>(first, last)
{
}
NGRAPH_API ~CoordinateDiff(); NGRAPH_API ~CoordinateDiff();
@ -47,12 +42,12 @@ namespace ngraph
{ {
public: public:
AttributeAdapter(CoordinateDiff& value) AttributeAdapter(CoordinateDiff& value)
: IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>(value) : IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>(value) {}
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
NGRAPH_API NGRAPH_API

View File

@ -10,17 +10,14 @@
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/variant.hpp" #include "ngraph/variant.hpp"
namespace ngraph namespace ngraph {
{
class Node; class Node;
namespace descriptor namespace descriptor {
{
class Output; class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple // Describes a tensor that is an input to an op, directly or indirectly via a tuple
class NGRAPH_API Input class NGRAPH_API Input {
{
friend class ngraph::Node; friend class ngraph::Node;
public: public:
@ -38,23 +35,37 @@ namespace ngraph
std::shared_ptr<Node> get_node() const; std::shared_ptr<Node> get_node() const;
/// \return the raw pointer to the node that this is an input of /// \return the raw pointer to the node that this is an input of
Node* get_raw_pointer_node() const { return m_node; } Node* get_raw_pointer_node() const {
return m_node;
}
/// \return the position within all supplied tensors of this input /// \return the position within all supplied tensors of this input
size_t get_index() const { return m_index; } size_t get_index() const {
return m_index;
}
/// \return the connected output /// \return the connected output
const Output& get_output() const { return *m_output; } const Output& get_output() const {
return *m_output;
}
/// \return the connected output /// \return the connected output
Output& get_output() { return *m_output; } Output& get_output() {
return *m_output;
}
/// \return true if an output is connected to the input. /// \return true if an output is connected to the input.
bool has_output() const { return m_output != nullptr; } bool has_output() const {
return m_output != nullptr;
}
/// \return the tensor of the connected output /// \return the tensor of the connected output
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
/// \return the tensor of the connected output /// \return the tensor of the connected output
Tensor& get_tensor(); Tensor& get_tensor();
RTMap& get_rt_info() { return m_rt_info; } RTMap& get_rt_info() {
const RTMap& get_rt_info() const { return m_rt_info; } return m_rt_info;
}
const RTMap& get_rt_info() const {
return m_rt_info;
}
/// \brief Replace the current output that supplies a value for this input with output i /// \brief Replace the current output that supplies a value for this input with output i
/// of node /// of node
@ -69,12 +80,16 @@ namespace ngraph
/// corresponding node. (Usually this is false.) /// corresponding node. (Usually this is false.)
/// ///
/// See Node::set_input_is_relevant_to_shape for more details. /// See Node::set_input_is_relevant_to_shape for more details.
bool get_is_relevant_to_shape() const { return m_is_relevant_to_shape; } bool get_is_relevant_to_shape() const {
return m_is_relevant_to_shape;
}
/// \return true if the value of this input is relevant to the output value of the /// \return true if the value of this input is relevant to the output value of the
/// corresponding node. (Usually this is true.) /// corresponding node. (Usually this is true.)
/// ///
/// See Node::set_input_is_relevant_to_value for more details. /// See Node::set_input_is_relevant_to_value for more details.
bool get_is_relevant_to_value() const { return m_is_relevant_to_value; } bool get_is_relevant_to_value() const {
return m_is_relevant_to_value;
}
protected: protected:
/// \return the tensor for the connected output /// \return the tensor for the connected output

View File

@ -15,8 +15,7 @@
#include "ngraph/node_output.hpp" #include "ngraph/node_output.hpp"
#include "ngraph/variant.hpp" #include "ngraph/variant.hpp"
namespace ngraph namespace ngraph {
{
// The forward declaration of Node is needed here because Node has a deque of // The forward declaration of Node is needed here because Node has a deque of
// Outputs, and Output is an incomplete type at this point. STL containers of // Outputs, and Output is an incomplete type at this point. STL containers of
// incomplete type have undefined behavior according to the C++11 standard, and // incomplete type have undefined behavior according to the C++11 standard, and
@ -24,19 +23,11 @@ namespace ngraph
// systems (namely macOS). // systems (namely macOS).
class Node; class Node;
namespace descriptor namespace descriptor {
{
// Describes an output tensor of an op // Describes an output tensor of an op
class NGRAPH_API Output class NGRAPH_API Output {
{
public: public:
Output() Output() : m_node(nullptr), m_index(0), m_tensor(nullptr), m_inputs() {}
: m_node(nullptr)
, m_index(0)
, m_tensor(nullptr)
, m_inputs()
{
}
/// \param node Node that owns this output. /// \param node Node that owns this output.
/// \param index Position of the output tensor in all output tensors /// \param index Position of the output tensor in all output tensors
@ -44,17 +35,29 @@ namespace ngraph
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor); Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
std::shared_ptr<Node> get_node() const; std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; } size_t get_index() const {
return m_index;
}
ngraph::Output<Node> get_output() const; ngraph::Output<Node> get_output() const;
std::shared_ptr<Tensor> get_tensor_ptr() const { return m_tensor; } std::shared_ptr<Tensor> get_tensor_ptr() const {
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; } return m_tensor;
}
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) {
m_tensor = tensor;
}
void add_input(Input* input); void add_input(Input* input);
void remove_input(Input* input); void remove_input(Input* input);
const std::vector<Input*>& get_inputs() const { return m_inputs; } const std::vector<Input*>& get_inputs() const {
return m_inputs;
}
Tensor& get_tensor() const; Tensor& get_tensor() const;
RTMap& get_rt_info() { return m_rt_info; } RTMap& get_rt_info() {
const RTMap& get_rt_info() const { return m_rt_info; } return m_rt_info;
}
const RTMap& get_rt_info() const {
return m_rt_info;
}
/// \return the shape of the output /// \return the shape of the output
const Shape& get_shape() const; const Shape& get_shape() const;

View File

@ -14,31 +14,22 @@
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
namespace ngraph namespace ngraph {
{
class Node; class Node;
namespace runtime namespace runtime {
{
class HostTensor; class HostTensor;
} }
using HostTensorPtr = std::shared_ptr<runtime::HostTensor>; using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
namespace descriptor namespace descriptor {
{
/// \brief Compile-time descriptor of a first-class value that is a tensor. /// \brief Compile-time descriptor of a first-class value that is a tensor.
class NGRAPH_API Tensor class NGRAPH_API Tensor {
{
Tensor(const Tensor&) = delete; Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete; Tensor& operator=(const Tensor&) = delete;
public: public:
Tensor(const element::Type& element_type, Tensor(const element::Type& element_type, const PartialShape& pshape, const std::string& name);
const PartialShape& pshape, Tensor(const element::Type& element_type, const PartialShape& pshape, Node* node, size_t node_output_number);
const std::string& name);
Tensor(const element::Type& element_type,
const PartialShape& pshape,
Node* node,
size_t node_output_number);
NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.") NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
const std::string& get_name() const; const std::string& get_name() const;
@ -59,16 +50,23 @@ namespace ngraph
/// \brief unsets bound value descriptions /// \brief unsets bound value descriptions
void invalidate_values(); void invalidate_values();
const element::Type& get_element_type() const { return m_element_type; } const element::Type& get_element_type() const {
return m_element_type;
}
const Shape& get_shape() const; const Shape& get_shape() const;
const PartialShape& get_partial_shape() const { return m_partial_shape; } const PartialShape& get_partial_shape() const {
return m_partial_shape;
}
/// \brief gets lower bound value description /// \brief gets lower bound value description
HostTensorPtr get_lower_value() const { return m_lower_value; } HostTensorPtr get_lower_value() const {
return m_lower_value;
}
/// \brief gets upper bound value description /// \brief gets upper bound value description
HostTensorPtr get_upper_value() const { return m_upper_value; } HostTensorPtr get_upper_value() const {
return m_upper_value;
}
/// \brief checks if lower and upper bound are set and point to the same HostTensor /// \brief checks if lower and upper bound are set and point to the same HostTensor
bool has_and_set_bound() const bool has_and_set_bound() const {
{
return m_upper_value != nullptr && m_upper_value == m_lower_value; return m_upper_value != nullptr && m_upper_value == m_lower_value;
} }
size_t size() const; size_t size() const;

View File

@ -4,23 +4,22 @@
#pragma once #pragma once
#include <limits>
#include <stddef.h> #include <stddef.h>
#include <limits>
#include <stdexcept> #include <stdexcept>
#include "ngraph/deprecated.hpp" #include "ngraph/deprecated.hpp"
#include "ngraph/interval.hpp" #include "ngraph/interval.hpp"
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph {
{
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime), /// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object. /// in a shape or shape-like object.
/// ///
/// Static dimensions may be implicitly converted from value_type. A dynamic dimension is /// Static dimensions may be implicitly converted from value_type. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic(). /// constructed with Dimension() or Dimension::dynamic().
class NGRAPH_API Dimension class NGRAPH_API Dimension {
{
public: public:
using value_type = int64_t; using value_type = int64_t;
@ -36,20 +35,22 @@ namespace ngraph
/// \brief Construct a dynamic dimension with range [0, ...] /// \brief Construct a dynamic dimension with range [0, ...]
Dimension() = default; Dimension() = default;
bool operator==(const Dimension& dimension) const bool operator==(const Dimension& dimension) const {
{
return m_dimension == dimension.m_dimension; return m_dimension == dimension.m_dimension;
} }
bool operator!=(const Dimension& dimension) const bool operator!=(const Dimension& dimension) const {
{
return m_dimension != dimension.m_dimension; return m_dimension != dimension.m_dimension;
} }
/// \brief Check whether this dimension is static. /// \brief Check whether this dimension is static.
/// \return `true` if the dimension is static, else `false`. /// \return `true` if the dimension is static, else `false`.
bool is_static() const { return m_dimension.size() == 1; } bool is_static() const {
return m_dimension.size() == 1;
}
/// \brief Check whether this dimension is dynamic. /// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`. /// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return m_dimension.size() != 1; } bool is_dynamic() const {
return m_dimension.size() != 1;
}
/// \brief Convert this dimension to `value_type`. This dimension must be static and /// \brief Convert this dimension to `value_type`. This dimension must be static and
/// non-negative. /// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative. /// \throws std::invalid_argument If this dimension is dynamic or negative.
@ -59,8 +60,12 @@ namespace ngraph
value_type get_max_length() const; value_type get_max_length() const;
/// \brief Return the interval of valid lengths /// \brief Return the interval of valid lengths
const Interval& get_interval() const { return m_dimension; } const Interval& get_interval() const {
Interval& get_interval() { return m_dimension; } return m_dimension;
}
Interval& get_interval() {
return m_dimension;
}
/// \brief Check whether this dimension represents the same scheme as the argument (both /// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal). /// dynamic, or equal).
/// \param dim The other dimension to compare this dimension to. /// \param dim The other dimension to compare this dimension to.
@ -115,7 +120,9 @@ namespace ngraph
/// \brief Create a dynamic dimension. /// \brief Create a dynamic dimension.
/// \return A dynamic dimension. /// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); } static Dimension dynamic() {
return Dimension();
}
/// \brief Addition operator for Dimension. /// \brief Addition operator for Dimension.
/// \param dim Right operand for addition. /// \param dim Right operand for addition.
/// \return Smallest interval dimension enclosing inputs /// \return Smallest interval dimension enclosing inputs
@ -135,21 +142,22 @@ namespace ngraph
/// \brief Add-into operator for Dimension. /// \brief Add-into operator for Dimension.
/// \param dim Right operand for addition. /// \param dim Right operand for addition.
/// \return A reference to `*this`, after updating `*this` to the value `*this + dim`. /// \return A reference to `*this`, after updating `*this` to the value `*this + dim`.
Dimension& operator+=(const Dimension& dim) { return (*this = *this + dim); } Dimension& operator+=(const Dimension& dim) {
return (*this = *this + dim);
}
/// \brief Multiply-into operator for Dimension. /// \brief Multiply-into operator for Dimension.
/// \param dim Right operand for multiplication. /// \param dim Right operand for multiplication.
/// \return A reference to `*this`, after updating `*this` to the value `*this * dim`. /// \return A reference to `*this`, after updating `*this` to the value `*this * dim`.
Dimension& operator*=(const Dimension& dim) { return (*this = *this * dim); } Dimension& operator*=(const Dimension& dim) {
return (*this = *this * dim);
}
/// \brief Intersection of dimensions /// \brief Intersection of dimensions
Dimension operator&(const Dimension& dim) const; Dimension operator&(const Dimension& dim) const;
/// \brief Intersection of dimensions /// \brief Intersection of dimensions
Dimension& operator&=(const Dimension& dim); Dimension& operator&=(const Dimension& dim);
private: private:
Dimension(const Interval& interval) Dimension(const Interval& interval) : m_dimension(interval) {}
: m_dimension(interval)
{
}
// The actual numerical value of the dimension. // The actual numerical value of the dimension.
Interval m_dimension{}; Interval m_dimension{};

View File

@ -12,12 +12,9 @@
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
namespace ngraph namespace ngraph {
{ namespace reduction {
namespace reduction enum class Type {
{
enum class Type
{
SUM, SUM,
PROD, PROD,
MIN, MIN,
@ -29,16 +26,13 @@ namespace ngraph
} // namespace reduction } // namespace reduction
template <> template <>
class NGRAPH_API AttributeAdapter<reduction::Type> class NGRAPH_API AttributeAdapter<reduction::Type> : public EnumAttributeAdapterBase<reduction::Type> {
: public EnumAttributeAdapterBase<reduction::Type>
{
public: public:
AttributeAdapter(reduction::Type& value) AttributeAdapter(reduction::Type& value) : EnumAttributeAdapterBase<reduction::Type>(value) {}
: EnumAttributeAdapterBase<reduction::Type>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
} // namespace ngraph } // namespace ngraph

View File

@ -10,17 +10,14 @@
#include "ngraph/check.hpp" #include "ngraph/check.hpp"
namespace ngraph namespace ngraph {
{
/// Uses a pairings defined by EnumTypes::get() to convert between strings /// Uses a pairings defined by EnumTypes::get() to convert between strings
/// and enum values. /// and enum values.
template <typename EnumType> template <typename EnumType>
class EnumNames class EnumNames {
{
public: public:
/// Converts strings to enum values /// Converts strings to enum values
static EnumType as_enum(const std::string& name) static EnumType as_enum(const std::string& name) {
{
auto to_lower = [](const std::string& s) { auto to_lower = [](const std::string& s) {
std::string rc = s; std::string rc = s;
std::transform(rc.begin(), rc.end(), rc.begin(), [](char c) { std::transform(rc.begin(), rc.end(), rc.begin(), [](char c) {
@ -28,10 +25,8 @@ namespace ngraph
}); });
return rc; return rc;
}; };
for (const auto& p : get().m_string_enums) for (const auto& p : get().m_string_enums) {
{ if (to_lower(p.first) == to_lower(name)) {
if (to_lower(p.first) == to_lower(name))
{
return p.second; return p.second;
} }
} }
@ -39,12 +34,9 @@ namespace ngraph
} }
/// Converts enum values to strings /// Converts enum values to strings
static const std::string& as_string(EnumType e) static const std::string& as_string(EnumType e) {
{ for (const auto& p : get().m_string_enums) {
for (const auto& p : get().m_string_enums) if (p.second == e) {
{
if (p.second == e)
{
return p.first; return p.first;
} }
} }
@ -53,12 +45,9 @@ namespace ngraph
private: private:
/// Creates the mapping. /// Creates the mapping.
EnumNames(const std::string& enum_name, EnumNames(const std::string& enum_name, const std::vector<std::pair<std::string, EnumType>> string_enums)
const std::vector<std::pair<std::string, EnumType>> string_enums) : m_enum_name(enum_name),
: m_enum_name(enum_name) m_string_enums(string_enums) {}
, m_string_enums(string_enums)
{
}
/// Must be defined to returns a singleton for each supported enum class /// Must be defined to returns a singleton for each supported enum class
static EnumNames<EnumType>& get(); static EnumNames<EnumType>& get();
@ -69,16 +58,13 @@ namespace ngraph
/// Returns the enum value matching the string /// Returns the enum value matching the string
template <typename Type, typename Value> template <typename Type, typename Value>
typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type as_enum(const Value& value) {
as_enum(const Value& value)
{
return EnumNames<Type>::as_enum(value); return EnumNames<Type>::as_enum(value);
} }
/// Returns the string matching the enum value /// Returns the string matching the enum value
template <typename Value> template <typename Value>
const std::string& as_string(Value value) const std::string& as_string(Value value) {
{
return EnumNames<Value>::as_string(value); return EnumNames<Value>::as_string(value);
} }
} // namespace ngraph } // namespace ngraph

View File

@ -5,12 +5,10 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <ngraph/ngraph_visibility.hpp>
#include <string> #include <string>
#include <ngraph/ngraph_visibility.hpp> namespace ngraph {
namespace ngraph
{
/// \brief Get the names environment variable as a string. /// \brief Get the names environment variable as a string.
/// \param env_var The string name of the environment variable to get. /// \param env_var The string name of the environment variable to get.
/// \return Returns string by value or an empty string if the environment /// \return Returns string by value or an empty string if the environment

View File

@ -12,14 +12,12 @@
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type_traits.hpp" #include "ngraph/type/element_type_traits.hpp"
namespace ngraph namespace ngraph {
{
/// \brief Execute handlers on a subgraph to compute values /// \brief Execute handlers on a subgraph to compute values
/// ///
/// ///
template <typename V> template <typename V>
class Evaluator class Evaluator {
{
public: public:
/// \brief values we compute for outputs /// \brief values we compute for outputs
using value_map = std::map<RawNodeOutput, V>; using value_map = std::map<RawNodeOutput, V>;
@ -41,37 +39,40 @@ namespace ngraph
/// Evaluator::get_value_map(). /// Evaluator::get_value_map().
/// ///
/// \param Handlers for ops. Pairs of Node::type_info_t and handler functions. /// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
Evaluator(const op_handler_map& handlers, value_map& values) Evaluator(const op_handler_map& handlers, value_map& values) : m_handlers(handlers), m_value_map(values) {}
: m_handlers(handlers)
, m_value_map(values)
{
}
/// \brief Retrieves the value_map, which holds all Output<Node> value associations. /// \brief Retrieves the value_map, which holds all Output<Node> value associations.
value_map& get_value_map() { return m_value_map; } value_map& get_value_map() {
const value_map& get_value_map() const { return m_value_map; } return m_value_map;
}
const value_map& get_value_map() const {
return m_value_map;
}
/// \brief If set, handles all ops /// \brief If set, handles all ops
const op_handler& get_univeral_handler() const { return m_universal_handler; } const op_handler& get_univeral_handler() const {
return m_universal_handler;
}
/// \brief If set, handles all ops not in the handlers /// \brief If set, handles all ops not in the handlers
const op_handler& get_default_handler() const { return m_default_handler; } const op_handler& get_default_handler() const {
return m_default_handler;
}
/// \brief If set, handles all ops /// \brief If set, handles all ops
void set_univeral_handler(const op_handler& handler) { m_universal_handler = handler; } void set_univeral_handler(const op_handler& handler) {
m_universal_handler = handler;
}
/// \brief If set, handles all ops not in the handlers /// \brief If set, handles all ops not in the handlers
void set_default_handler(const op_handler& handler) { m_default_handler = handler; } void set_default_handler(const op_handler& handler) {
m_default_handler = handler;
}
protected: protected:
op_handler get_handler(Node* node) op_handler get_handler(Node* node) {
{
op_handler handler = m_universal_handler; op_handler handler = m_universal_handler;
if (!handler) if (!handler) {
{
auto it = m_handlers.find(node->get_type_info()); auto it = m_handlers.find(node->get_type_info());
if (it == m_handlers.end()) if (it == m_handlers.end()) {
{
handler = m_default_handler; handler = m_default_handler;
} } else {
else
{
handler = it->second; handler = it->second;
} }
} }
@ -83,56 +84,39 @@ namespace ngraph
using InstStack = std::stack<InstPtr>; using InstStack = std::stack<InstPtr>;
/// \brief Intstructions for evaluations state machine /// \brief Intstructions for evaluations state machine
class Inst class Inst {
{
protected: protected:
Inst(Node* node) Inst(Node* node) : m_node(node) {}
: m_node(node)
{
}
public: public:
virtual ~Inst() {} virtual ~Inst() {}
virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0; virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
Node* get_node() { return m_node; } Node* get_node() {
return m_node;
}
protected: protected:
Node* m_node; Node* m_node;
}; };
/// \brief Ensure value has been analyzed /// \brief Ensure value has been analyzed
class ValueInst : public Inst class ValueInst : public Inst {
{
public: public:
ValueInst(const Output<Node>& value) ValueInst(const Output<Node>& value) : Inst(value.get_node()), m_index(value.get_index()) {}
: Inst(value.get_node())
, m_index(value.get_index())
{
}
ValueInst(const RawNodeOutput& value) ValueInst(const RawNodeOutput& value) : Inst(value.node), m_index(value.index) {}
: Inst(value.node)
, m_index(value.index)
{
}
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
{
// Request to analyze this value if we can // Request to analyze this value if we can
if (auto handler = evaluator.get_handler(node)) if (auto handler = evaluator.get_handler(node)) {
{
// Ensure the inputs are processed and then execute the op handler // Ensure the inputs are processed and then execute the op handler
inst_stack.push(InstPtr(new ExecuteInst(node, handler))); inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
for (auto v : node->input_values()) for (auto v : node->input_values()) {
{
inst_stack.push(InstPtr(new ValueInst(v))); inst_stack.push(InstPtr(new ValueInst(v)));
} }
} } else {
else
{
// We don't know how to handle this op, so mark the outputs as unknown // We don't know how to handle this op, so mark the outputs as unknown
for (auto output : node->outputs()) for (auto output : node->outputs()) {
{
evaluator.get_value_map()[output] = V(); evaluator.get_value_map()[output] = V();
} }
} }
@ -143,27 +127,19 @@ namespace ngraph
}; };
/// \brief All arguments have been handled; execute the node handler /// \brief All arguments have been handled; execute the node handler
class ExecuteInst : public Inst class ExecuteInst : public Inst {
{
public: public:
ExecuteInst(Node* node, op_handler& handler) ExecuteInst(Node* node, op_handler& handler) : Inst(node), m_handler(handler) {}
: Inst(node)
, m_handler(handler)
{
}
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
{
// Request to execute the handleer. Pass what we know about the inputs to the // Request to execute the handleer. Pass what we know about the inputs to the
// handler and associate the results with the outputs // handler and associate the results with the outputs
std::vector<V> inputs; std::vector<V> inputs;
for (auto v : node->input_values()) for (auto v : node->input_values()) {
{
inputs.push_back(evaluator.get_value_map().at(v)); inputs.push_back(evaluator.get_value_map().at(v));
} }
std::vector<V> outputs = m_handler(node, inputs); std::vector<V> outputs = m_handler(node, inputs);
for (size_t i = 0; i < outputs.size(); ++i) for (size_t i = 0; i < outputs.size(); ++i) {
{
evaluator.get_value_map()[node->output(i)] = outputs[i]; evaluator.get_value_map()[node->output(i)] = outputs[i];
} }
} }
@ -174,18 +150,15 @@ namespace ngraph
public: public:
/// \brief Determine information about value /// \brief Determine information about value
V evaluate(const Output<Node>& value) V evaluate(const Output<Node>& value) {
{
InstStack inst_stack; InstStack inst_stack;
inst_stack.push(InstPtr(new ValueInst(value))); inst_stack.push(InstPtr(new ValueInst(value)));
while (!inst_stack.empty()) while (!inst_stack.empty()) {
{
InstPtr inst; InstPtr inst;
std::swap(inst_stack.top(), inst); std::swap(inst_stack.top(), inst);
inst_stack.pop(); inst_stack.pop();
auto node = inst->get_node(); auto node = inst->get_node();
if (m_value_map.find(node->output(0)) != m_value_map.end()) if (m_value_map.find(node->output(0)) != m_value_map.end()) {
{
// Already computed // Already computed
continue; continue;
} }

View File

@ -4,39 +4,23 @@
#pragma once #pragma once
#include <ngraph/ngraph_visibility.hpp>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <ngraph/ngraph_visibility.hpp> namespace ngraph {
namespace ngraph
{
/// Base error for ngraph runtime errors. /// Base error for ngraph runtime errors.
class NGRAPH_API ngraph_error : public std::runtime_error class NGRAPH_API ngraph_error : public std::runtime_error {
{
public: public:
explicit ngraph_error(const std::string& what_arg) explicit ngraph_error(const std::string& what_arg) : std::runtime_error(what_arg) {}
: std::runtime_error(what_arg)
{
}
explicit ngraph_error(const char* what_arg) explicit ngraph_error(const char* what_arg) : std::runtime_error(what_arg) {}
: std::runtime_error(what_arg)
{
}
explicit ngraph_error(const std::stringstream& what_arg) explicit ngraph_error(const std::stringstream& what_arg) : std::runtime_error(what_arg.str()) {}
: std::runtime_error(what_arg.str())
{
}
}; };
class NGRAPH_API unsupported_op : public std::runtime_error class NGRAPH_API unsupported_op : public std::runtime_error {
{
public: public:
unsupported_op(const std::string& what_arg) unsupported_op(const std::string& what_arg) : std::runtime_error(what_arg) {}
: std::runtime_error(what_arg)
{
}
}; };
} // namespace ngraph } // namespace ngraph

View File

@ -10,63 +10,56 @@
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph {
{
NGRAPH_API std::mutex& get_registry_mutex(); NGRAPH_API std::mutex& get_registry_mutex();
/// \brief Registry of factories that can construct objects derived from BASE_TYPE /// \brief Registry of factories that can construct objects derived from BASE_TYPE
template <typename BASE_TYPE> template <typename BASE_TYPE>
class FactoryRegistry class FactoryRegistry {
{
public: public:
using Factory = std::function<BASE_TYPE*()>; using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>; using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed. // \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
static Factory get_default_factory() static Factory get_default_factory() {
{ return []() {
return []() { return new DERIVED_TYPE(); }; return new DERIVED_TYPE();
};
} }
/// \brief Register a custom factory for type_info /// \brief Register a custom factory for type_info
void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory) void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory) {
{
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[type_info] = factory; m_factory_map[type_info] = factory;
} }
/// \brief Register a custom factory for DERIVED_TYPE /// \brief Register a custom factory for DERIVED_TYPE
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
void register_factory(Factory factory) void register_factory(Factory factory) {
{
register_factory(DERIVED_TYPE::type_info, factory); register_factory(DERIVED_TYPE::type_info, factory);
} }
/// \brief Register the defualt constructor factory for DERIVED_TYPE /// \brief Register the defualt constructor factory for DERIVED_TYPE
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
void register_factory() void register_factory() {
{
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>()); register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
} }
/// \brief Check to see if a factory is registered /// \brief Check to see if a factory is registered
bool has_factory(const typename BASE_TYPE::type_info_t& info) bool has_factory(const typename BASE_TYPE::type_info_t& info) {
{
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end(); return m_factory_map.find(info) != m_factory_map.end();
} }
/// \brief Check to see if DERIVED_TYPE has a registered factory /// \brief Check to see if DERIVED_TYPE has a registered factory
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
bool has_factory() bool has_factory() {
{
return has_factory(DERIVED_TYPE::type_info); return has_factory(DERIVED_TYPE::type_info);
} }
/// \brief Create an instance for type_info /// \brief Create an instance for type_info
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const {
{
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info); auto it = m_factory_map.find(type_info);
return it == m_factory_map.end() ? nullptr : it->second(); return it == m_factory_map.end() ? nullptr : it->second();
@ -74,8 +67,7 @@ namespace ngraph
/// \brief Create an instance using factory for DERIVED_TYPE /// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
BASE_TYPE* create() const BASE_TYPE* create() const {
{
return create(DERIVED_TYPE::type_info); return create(DERIVED_TYPE::type_info);
} }

View File

@ -8,37 +8,32 @@
#include "ngraph/attribute_visitor.hpp" #include "ngraph/attribute_visitor.hpp"
#include "ngraph/factory.hpp" #include "ngraph/factory.hpp"
namespace ngraph namespace ngraph {
{
template <typename BASE_TYPE> template <typename BASE_TYPE>
class FactoryAttributeAdapter : public VisitorAdapter class FactoryAttributeAdapter : public VisitorAdapter {
{
public: public:
FactoryAttributeAdapter(std::shared_ptr<BASE_TYPE>& ref) FactoryAttributeAdapter(std::shared_ptr<BASE_TYPE>& ref) : m_ref(ref) {}
: m_ref(ref)
{
}
/// \brief Hook for extra processing before other attributes /// \brief Hook for extra processing before other attributes
virtual bool on_start(AttributeVisitor& /* visitor */) { return true; } virtual bool on_start(AttributeVisitor& /* visitor */) {
return true;
}
/// \brief Hook for extra processing after other attributes /// \brief Hook for extra processing after other attributes
virtual bool on_finish(AttributeVisitor& /* visitor */) { return true; } virtual bool on_finish(AttributeVisitor& /* visitor */) {
bool visit_attributes(AttributeVisitor& visitor) override return true;
{ }
if (on_start(visitor)) bool visit_attributes(AttributeVisitor& visitor) override {
{ if (on_start(visitor)) {
std::string type_info_name; std::string type_info_name;
uint64_t type_info_version; uint64_t type_info_version;
if (m_ref) if (m_ref) {
{
auto& type_info = m_ref->get_type_info(); auto& type_info = m_ref->get_type_info();
type_info_name = type_info.name; type_info_name = type_info.name;
type_info_version = type_info.version; type_info_version = type_info.version;
} }
visitor.on_attribute("name", type_info_name); visitor.on_attribute("name", type_info_name);
visitor.on_attribute("version", type_info_version); visitor.on_attribute("version", type_info_version);
if (m_ref) if (m_ref) {
{
visitor.start_structure("value"); visitor.start_structure("value");
m_ref->visit_attributes(visitor); m_ref->visit_attributes(visitor);
visitor.finish_structure(); visitor.finish_structure();

View File

@ -5,15 +5,12 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <ngraph/ngraph_visibility.hpp>
#include <string> #include <string>
#include <vector> #include <vector>
#include <ngraph/ngraph_visibility.hpp> namespace ngraph {
namespace file_util {
namespace ngraph
{
namespace file_util
{
/// \brief Returns the name with extension for a given path /// \brief Returns the name with extension for a given path
/// \param path The path to the output file /// \param path The path to the output file
NGRAPH_API NGRAPH_API
@ -37,10 +34,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3); std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3);
NGRAPH_API NGRAPH_API
std::string path_join(const std::string& s1, std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3, const std::string& s4);
const std::string& s2,
const std::string& s3,
const std::string& s4);
/// \brief Iterate through files and optionally directories. Symbolic links are skipped. /// \brief Iterate through files and optionally directories. Symbolic links are skipped.
/// \param path The path to iterate over /// \param path The path to iterate over

View File

@ -20,29 +20,21 @@
#include "ngraph/op/sink.hpp" #include "ngraph/op/sink.hpp"
#include "ngraph/op/util/variable.hpp" #include "ngraph/op/util/variable.hpp"
namespace ngraph namespace ngraph {
{
/// A user-defined function. /// A user-defined function.
class NGRAPH_API Function class NGRAPH_API Function {
{
public: public:
static constexpr DiscreteTypeInfo type_info{"Function", 0}; static constexpr DiscreteTypeInfo type_info{"Function", 0};
const DiscreteTypeInfo& get_type_info() const { return type_info; } const DiscreteTypeInfo& get_type_info() const {
Function(const NodeVector& results, return type_info;
const ParameterVector& parameters, }
const std::string& name = ""); Function(const NodeVector& results, const ParameterVector& parameters, const std::string& name = "");
Function(const OutputVector& results, Function(const OutputVector& results, const ParameterVector& parameters, const std::string& name = "");
const ParameterVector& parameters,
const std::string& name = "");
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result, const ParameterVector& parameters, const std::string& name = "");
const ParameterVector& parameters,
const std::string& name = "");
Function(const ResultVector& results, Function(const ResultVector& results, const ParameterVector& parameters, const std::string& name = "");
const ParameterVector& parameters,
const std::string& name = "");
Function(const ResultVector& results, Function(const ResultVector& results,
const SinkVector& sinks, const SinkVector& sinks,
@ -82,9 +74,7 @@ namespace ngraph
/// Constructs a Function. Lists of parameters and variables will be generated automatically /// Constructs a Function. Lists of parameters and variables will be generated automatically
/// based on traversing the graph from the results and the sinks. /// based on traversing the graph from the results and the sinks.
Function(const OutputVector& results, Function(const OutputVector& results, const SinkVector& sinks, const std::string& name = "");
const SinkVector& sinks,
const std::string& name = "");
virtual ~Function() = default; virtual ~Function() = default;
/// Return the number of outputs for this function. /// Return the number of outputs for this function.
@ -147,19 +137,22 @@ namespace ngraph
/// ///
/// \param parameter_index The index of the parameter to replace. /// \param parameter_index The index of the parameter to replace.
/// \param parameter The parameter to substitute for the `parameter_index`th parameter. /// \param parameter The parameter to substitute for the `parameter_index`th parameter.
void replace_parameter(size_t parameter_index, void replace_parameter(size_t parameter_index, const std::shared_ptr<op::Parameter>& parameter);
const std::shared_ptr<op::Parameter>& parameter);
using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>( using topological_sort_t =
const std::vector<std::shared_ptr<Node>>& root_nodes)>; std::function<std::vector<std::shared_ptr<Node>>(const std::vector<std::shared_ptr<Node>>& root_nodes)>;
void set_topological_sort(topological_sort_t); void set_topological_sort(topological_sort_t);
virtual bool visit_attributes(AttributeVisitor& visitor); virtual bool visit_attributes(AttributeVisitor& visitor);
/// Return the function parameters /// Return the function parameters
const ParameterVector& get_parameters() const { return m_parameters; }; const ParameterVector& get_parameters() const {
return m_parameters;
};
/// Return a list of function's outputs /// Return a list of function's outputs
const ResultVector& get_results() const { return m_results; }; const ResultVector& get_results() const {
return m_results;
};
/// Index for parameter, or -1 /// Index for parameter, or -1
int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const; int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
@ -176,7 +169,9 @@ namespace ngraph
EvaluationContext evaluation_context = EvaluationContext()) const; EvaluationContext evaluation_context = EvaluationContext()) const;
/// \brief Return a list of function's sinks. /// \brief Return a list of function's sinks.
const SinkVector& get_sinks() const { return m_sinks; } const SinkVector& get_sinks() const {
return m_sinks;
}
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done /// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
/// manually after all changes. /// manually after all changes.
/// \param sinks new sink nodes /// \param sinks new sink nodes
@ -234,7 +229,9 @@ namespace ngraph
void remove_variable(const VariablePtr& variable); void remove_variable(const VariablePtr& variable);
/// \brief Return a list of function's variables. /// \brief Return a list of function's variables.
const VariableVector& get_variables() const { return m_variables; } const VariableVector& get_variables() const {
return m_variables;
}
/// \brief Return a variable by specified variable_id. /// \brief Return a variable by specified variable_id.
VariablePtr get_variable_by_id(const std::string& variable_id) const; VariablePtr get_variable_by_id(const std::string& variable_id) const;
@ -268,17 +265,13 @@ namespace ngraph
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public DirectValueAccessor<std::shared_ptr<Function>> {
: public DirectValueAccessor<std::shared_ptr<Function>>
{
public: public:
AttributeAdapter(std::shared_ptr<Function>& value) AttributeAdapter(std::shared_ptr<Function>& value) : DirectValueAccessor<std::shared_ptr<Function>>(value) {}
: DirectValueAccessor<std::shared_ptr<Function>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>", static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>", 0};
0}; const DiscreteTypeInfo& get_type_info() const override {
const DiscreteTypeInfo& get_type_info() const override { return type_info; } return type_info;
}
}; };
} // namespace ngraph } // namespace ngraph

View File

@ -18,25 +18,20 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace ngraph namespace ngraph {
{ namespace descriptor {
namespace descriptor
{
class Input; class Input;
class Output; class Output;
} // namespace descriptor } // namespace descriptor
namespace op namespace op {
{ namespace v0 {
namespace v0
{
class Parameter; class Parameter;
} }
} // namespace op } // namespace op
NGRAPH_API NGRAPH_API
void traverse_nodes(const std::shared_ptr<const Function> p, void traverse_nodes(const std::shared_ptr<const Function> p, std::function<void(std::shared_ptr<Node>)> f);
std::function<void(std::shared_ptr<Node>)> f);
NGRAPH_API NGRAPH_API
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f); void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
@ -228,62 +223,49 @@ namespace ngraph
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in /// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
/// `body_replacement_map`, behavior is unspecified. /// `body_replacement_map`, behavior is unspecified.
NGRAPH_API NGRAPH_API
void replace_nodes( void replace_nodes(const std::shared_ptr<Function>& f,
const std::shared_ptr<Function>& f, const std::unordered_map<std::shared_ptr<op::v0::Parameter>, std::shared_ptr<op::v0::Parameter>>&
const std::unordered_map<std::shared_ptr<op::v0::Parameter>, parameter_replacement_map,
std::shared_ptr<op::v0::Parameter>>& parameter_replacement_map, const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>& body_replacement_map);
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>&
body_replacement_map);
NGRAPH_API NGRAPH_API
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// Topological sort of nodes needed to compute root_nodes /// Topological sort of nodes needed to compute root_nodes
template <typename T> template <typename T>
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes) std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes) {
{
std::stack<Node*, std::vector<Node*>> nodes_to_do; std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done; std::unordered_set<Node*> nodes_done;
std::vector<std::shared_ptr<Node>> result; std::vector<std::shared_ptr<Node>> result;
for (auto& node : root_nodes) for (auto& node : root_nodes) {
{
nodes_to_do.push(node.get()); nodes_to_do.push(node.get());
} }
while (nodes_to_do.size() > 0) while (nodes_to_do.size() > 0) {
{
Node* node = nodes_to_do.top(); Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0) if (nodes_done.count(node) == 0) {
{
bool can_add = true; bool can_add = true;
size_t arg_count = node->get_input_size(); size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i) for (size_t i = 0; i < arg_count; ++i) {
{
Node* dep = node->get_input_node_ptr(arg_count - i - 1); Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0) if (nodes_done.count(dep) == 0) {
{
can_add = false; can_add = false;
nodes_to_do.push(dep); nodes_to_do.push(dep);
} }
} }
for (auto& depptr : node->get_control_dependencies()) for (auto& depptr : node->get_control_dependencies()) {
{
Node* dep = depptr.get(); Node* dep = depptr.get();
if (nodes_done.count(dep) == 0) if (nodes_done.count(dep) == 0) {
{
can_add = false; can_add = false;
nodes_to_do.push(dep); nodes_to_do.push(dep);
} }
} }
if (can_add) if (can_add) {
{
result.push_back(node->shared_from_this()); result.push_back(node->shared_from_this());
nodes_to_do.pop(); nodes_to_do.pop();
nodes_done.insert(node); nodes_done.insert(node);
} }
} } else {
else
{
nodes_to_do.pop(); nodes_to_do.pop();
} }
} }
@ -292,49 +274,39 @@ namespace ngraph
/// Topological sort of just nodes /// Topological sort of just nodes
template <typename T> template <typename T>
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes) std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes) {
{
std::stack<Node*, std::vector<Node*>> nodes_to_do; std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done; std::unordered_set<Node*> nodes_done;
std::unordered_set<Node*> nodes_to_emit; std::unordered_set<Node*> nodes_to_emit;
std::vector<std::shared_ptr<Node>> result; std::vector<std::shared_ptr<Node>> result;
for (auto& node : nodes) for (auto& node : nodes) {
{
nodes_to_emit.insert(node.get()); nodes_to_emit.insert(node.get());
nodes_to_do.push(node.get()); nodes_to_do.push(node.get());
} }
// NB: Some centos versions implement std::list::size() by counting elements // NB: Some centos versions implement std::list::size() by counting elements
size_t nodes_remaining = nodes_to_emit.size(); size_t nodes_remaining = nodes_to_emit.size();
while (nodes_to_do.size() > 0 && nodes_remaining > 0) while (nodes_to_do.size() > 0 && nodes_remaining > 0) {
{
Node* node = nodes_to_do.top(); Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0) if (nodes_done.count(node) == 0) {
{
bool can_add = true; bool can_add = true;
size_t arg_count = node->get_input_size(); size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i) for (size_t i = 0; i < arg_count; ++i) {
{
Node* dep = node->get_input_node_ptr(arg_count - i - 1); Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0) if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0) {
{
can_add = false; can_add = false;
nodes_to_do.push(dep); nodes_to_do.push(dep);
} }
} }
for (auto& depptr : node->get_control_dependencies()) for (auto& depptr : node->get_control_dependencies()) {
{
Node* dep = depptr.get(); Node* dep = depptr.get();
if (nodes_done.count(dep) == 0) if (nodes_done.count(dep) == 0) {
{
can_add = false; can_add = false;
nodes_to_do.push(dep); nodes_to_do.push(dep);
} }
} }
if (can_add) if (can_add) {
{ if (nodes_to_emit.count(node) != 0) {
if (nodes_to_emit.count(node) != 0)
{
result.push_back(node->shared_from_this()); result.push_back(node->shared_from_this());
nodes_remaining--; nodes_remaining--;
} }
@ -343,8 +315,7 @@ namespace ngraph
} }
} }
else else {
{
nodes_to_do.pop(); nodes_to_do.pop();
} }
} }
@ -352,10 +323,8 @@ namespace ngraph
} }
template <typename T> template <typename T>
void validate_nodes_and_infer_types(const T& nodes) void validate_nodes_and_infer_types(const T& nodes) {
{ for (auto& node : subgraph_topological_sort(nodes)) {
for (auto& node : subgraph_topological_sort(nodes))
{
node->revalidate_and_infer_types(); node->revalidate_and_infer_types();
} }
} }
@ -365,38 +334,35 @@ namespace ngraph
bool is_post_dominated(Node* X, Node* Y); bool is_post_dominated(Node* X, Node* Y);
NGRAPH_API NGRAPH_API
bool is_equal_to_const_value(const std::string& const_value, bool is_equal_to_const_value(const std::string& const_value, const Output<Node>& reduce_constant);
const Output<Node>& reduce_constant);
// input nodes are cloned and returned // input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes // NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes // NodeMap output (by reference) fully maps input and cloned nodes
NGRAPH_API NGRAPH_API
std::vector<std::shared_ptr<ngraph::Node>> std::vector<std::shared_ptr<ngraph::Node>> clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map); NodeMap& node_map);
// input nodes are cloned and returned // input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes // NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes // NodeMap output (by reference) fully maps input and cloned nodes
NGRAPH_API NGRAPH_API
std::list<std::shared_ptr<ngraph::Node>> std::list<std::shared_ptr<ngraph::Node>> clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& node_map); RawNodeOutputMap& node_map);
// input function is cloned and returned // input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes // NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops // NodeMap output (by reference) fully maps input and cloned function ops
NGRAPH_API NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func, std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func, NodeMap& node_map);
NodeMap& node_map);
// input function is cloned and returned // input function is cloned and returned
NGRAPH_API NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func); std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
NGRAPH_API NGRAPH_API
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>> std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>> insert_result_parameter_split(
insert_result_parameter_split(const std::shared_ptr<Node>& src_node, const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node); const std::shared_ptr<Node>& dst_node);
NGRAPH_API NGRAPH_API
@ -408,9 +374,7 @@ namespace ngraph
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape); std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape);
NGRAPH_API NGRAPH_API
std::shared_ptr<Node> make_constant_from_string(std::string val, std::shared_ptr<Node> make_constant_from_string(std::string val, const element::Type& element_type, const Shape& shape);
const element::Type& element_type,
const Shape& shape);
NGRAPH_API NGRAPH_API
bool is_zero(const Output<Node>& reduce_constant); bool is_zero(const Output<Node>& reduce_constant);
@ -454,8 +418,7 @@ namespace ngraph
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks); bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
NGRAPH_API NGRAPH_API
void plot_graph( void plot_graph(std::shared_ptr<Function> f,
std::shared_ptr<Function> f,
const std::string& filename, const std::string& filename,
std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr); std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr);
@ -472,9 +435,7 @@ namespace ngraph
/// going forward. /// going forward.
/// It returns true if a cycle is found and the first cycle encountered. /// It returns true if a cycle is found and the first cycle encountered.
NGRAPH_API NGRAPH_API
bool check_for_cycles(const ngraph::Function* func, bool check_for_cycles(const ngraph::Function* func, ngraph::NodeVector& cycle_nodes, bool& is_bkwd_cycle);
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle);
NGRAPH_API NGRAPH_API
bool replace_output_update_name(Output<Node> node, const Output<Node>& node_input); bool replace_output_update_name(Output<Node> node, const Output<Node>& node_input);

View File

@ -12,8 +12,7 @@
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph {
{
/// \brief Interval arithmetic /// \brief Interval arithmetic
/// ///
/// An interval is the set of integers from m_min_val through m_max_val. /// An interval is the set of integers from m_min_val through m_max_val.
@ -21,8 +20,7 @@ namespace ngraph
/// addition, subtraction, or multiplication of intervals is the smallest interval /// addition, subtraction, or multiplication of intervals is the smallest interval
/// containing the sums, differences, or products of elements of the two intervals. An empty /// containing the sums, differences, or products of elements of the two intervals. An empty
/// interval is canonicalized to [s_max, s_max]. /// interval is canonicalized to [s_max, s_max].
class NGRAPH_API Interval class NGRAPH_API Interval {
{
public: public:
using value_type = std::int64_t; using value_type = std::int64_t;
using size_type = std::uint64_t; using size_type = std::uint64_t;
@ -41,26 +39,36 @@ namespace ngraph
Interval& operator=(const Interval& interval) = default; Interval& operator=(const Interval& interval) = default;
/// \brief The number of elements in the interval. Zero if max < min. /// \brief The number of elements in the interval. Zero if max < min.
size_type size() const size_type size() const {
{ if (m_max_val == s_max) {
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max; return m_min_val == s_max ? 0 : s_max;
} }
return m_max_val - m_min_val + 1; return m_max_val - m_min_val + 1;
} }
/// \brief Returns true if the interval has no elements /// \brief Returns true if the interval has no elements
bool empty() const { return m_min_val == s_max; } bool empty() const {
return m_min_val == s_max;
}
/// \brief the inclusive lower bound of the interval /// \brief the inclusive lower bound of the interval
value_type get_min_val() const { return m_min_val; } value_type get_min_val() const {
return m_min_val;
}
/// \brief Set the inclusive lower bound of the interval /// \brief Set the inclusive lower bound of the interval
void set_min_val(value_type val) { m_min_val = val; } void set_min_val(value_type val) {
m_min_val = val;
}
/// \brief the inclusive upper bound of the interval /// \brief the inclusive upper bound of the interval
value_type get_max_val() const { return m_max_val; } value_type get_max_val() const {
return m_max_val;
}
/// \brief Set the inclusive upper bound of the interval /// \brief Set the inclusive upper bound of the interval
void set_max_val(value_type val) { m_max_val = val; } void set_max_val(value_type val) {
m_max_val = val;
}
/// \brief True if the upper bound is finite /// \brief True if the upper bound is finite
bool has_upper_bound() const { return m_max_val != s_max; } bool has_upper_bound() const {
return m_max_val != s_max;
}
/// \brief True if min and max bounds match /// \brief True if min and max bounds match
bool operator==(const Interval& interval) const; bool operator==(const Interval& interval) const;
bool operator!=(const Interval& interval) const; bool operator!=(const Interval& interval) const;
@ -91,7 +99,9 @@ namespace ngraph
Interval& operator&=(const Interval& interval); Interval& operator&=(const Interval& interval);
/// \brief True if this interval includes value /// \brief True if this interval includes value
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; } bool contains(value_type value) const {
return m_min_val <= value && value <= m_max_val;
}
/// \brief True if this interval includes all the values in interval /// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const; bool contains(const Interval& interval) const;

View File

@ -16,80 +16,67 @@
# include <sys/time.h> # include <sys/time.h>
# include <unistd.h> # include <unistd.h>
#endif #endif
#include <ngraph/ngraph_visibility.hpp>
#include <vector> #include <vector>
#include <ngraph/ngraph_visibility.hpp> namespace ngraph {
class ConstString {
namespace ngraph
{
class ConstString
{
public: public:
template <size_t SIZE> template <size_t SIZE>
constexpr ConstString(const char (&p)[SIZE]) constexpr ConstString(const char (&p)[SIZE]) : m_string(p),
: m_string(p) m_size(SIZE) {}
, m_size(SIZE)
{
}
constexpr char operator[](size_t i) const constexpr char operator[](size_t i) const {
{
return i < m_size ? m_string[i] : throw std::out_of_range(""); return i < m_size ? m_string[i] : throw std::out_of_range("");
} }
constexpr const char* get_ptr(size_t offset) const constexpr const char* get_ptr(size_t offset) const {
{
return offset < m_size ? &m_string[offset] : m_string; return offset < m_size ? &m_string[offset] : m_string;
} }
constexpr size_t size() const { return m_size; } constexpr size_t size() const {
return m_size;
}
private: private:
const char* m_string; const char* m_string;
size_t m_size; size_t m_size;
}; };
constexpr const char* find_last(ConstString s, size_t offset, char ch) constexpr const char* find_last(ConstString s, size_t offset, char ch) {
{ return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
return offset == 0
? s.get_ptr(0)
: (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
} }
constexpr const char* find_last(ConstString s, char ch) constexpr const char* find_last(ConstString s, char ch) {
{
return find_last(s, s.size() - 1, ch); return find_last(s, s.size() - 1, ch);
} }
constexpr const char* get_file_name(ConstString s) { return find_last(s, '/'); } constexpr const char* get_file_name(ConstString s) {
constexpr const char* trim_file_name(ConstString root, ConstString s) return find_last(s, '/');
{ }
constexpr const char* trim_file_name(ConstString root, ConstString s) {
return s.get_ptr(root.size()); return s.get_ptr(root.size());
} }
enum class LOG_TYPE enum class LOG_TYPE {
{
_LOG_TYPE_ERROR, _LOG_TYPE_ERROR,
_LOG_TYPE_WARNING, _LOG_TYPE_WARNING,
_LOG_TYPE_INFO, _LOG_TYPE_INFO,
_LOG_TYPE_DEBUG, _LOG_TYPE_DEBUG,
}; };
class NGRAPH_API LogHelper class NGRAPH_API LogHelper {
{
public: public:
LogHelper(LOG_TYPE, LogHelper(LOG_TYPE, const char* file, int line, std::function<void(const std::string&)> m_handler_func);
const char* file,
int line,
std::function<void(const std::string&)> m_handler_func);
~LogHelper(); ~LogHelper();
std::ostream& stream() { return m_stream; } std::ostream& stream() {
return m_stream;
}
private: private:
std::function<void(const std::string&)> m_handler_func; std::function<void(const std::string&)> m_handler_func;
std::stringstream m_stream; std::stringstream m_stream;
}; };
class Logger class Logger {
{
friend class LogHelper; friend class LogHelper;
public: public:
@ -142,28 +129,21 @@ namespace ngraph
.stream() .stream()
#else #else
struct NullLogger struct NullLogger {};
{
};
template <typename T> template <typename T>
NullLogger&& operator<<(NullLogger&& logger, T&&) NullLogger&& operator<<(NullLogger&& logger, T&&) {
{
return std::move(logger); return std::move(logger);
} }
template <typename T> template <typename T>
NullLogger&& operator<<(NullLogger&& logger, const T&) NullLogger&& operator<<(NullLogger&& logger, const T&) {
{
return std::move(logger); return std::move(logger);
} }
inline NullLogger&& inline NullLogger&& operator<<(
operator<<(NullLogger&& logger, NullLogger&& logger,
std::basic_ostream<char, std::char_traits<char>>& (&)(std::basic_ostream< std::basic_ostream<char, std::char_traits<char>>& (&)(std::basic_ostream<char, std::char_traits<char>>&)) {
char,
std::char_traits<char>>&))
{
return std::move(logger); return std::move(logger);
} }

View File

@ -18,8 +18,7 @@
extern "C" NGRAPH_API const char* get_ngraph_version_string(); extern "C" NGRAPH_API const char* get_ngraph_version_string();
namespace ngraph namespace ngraph {
{
/// \brief Function to query parsed version information of the version of ngraph which /// \brief Function to query parsed version information of the version of ngraph which
/// contains this function. Version information strictly follows Semantic Versioning /// contains this function. Version information strictly follows Semantic Versioning
/// http://semver.org /// http://semver.org

View File

@ -28,8 +28,7 @@
# if defined __INTEL_COMPILER || defined _MSC_VER # if defined __INTEL_COMPILER || defined _MSC_VER
# define ENABLE_UNICODE_PATH_SUPPORT # define ENABLE_UNICODE_PATH_SUPPORT
# endif # endif
#elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || \ # elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || defined(__clang__)
defined(__clang__)
# define ENABLE_UNICODE_PATH_SUPPORT # define ENABLE_UNICODE_PATH_SUPPORT
# endif # endif
#endif #endif

View File

@ -37,8 +37,7 @@
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
#include "ngraph/variant.hpp" #include "ngraph/variant.hpp"
namespace ngraph namespace ngraph {
{
template <typename NodeType> template <typename NodeType>
class Input; class Input;
@ -50,8 +49,7 @@ namespace ngraph
class Function; class Function;
namespace runtime namespace runtime {
{
class HostTensor; class HostTensor;
} }
using HostTensor = runtime::HostTensor; using HostTensor = runtime::HostTensor;
@ -62,18 +60,15 @@ namespace ngraph
/// environment) for evaluating ngraph::function. /// environment) for evaluating ngraph::function.
using EvaluationContext = std::map<std::string, std::shared_ptr<Variant>>; using EvaluationContext = std::map<std::string, std::shared_ptr<Variant>>;
namespace op namespace op {
{
struct AutoBroadcastSpec; struct AutoBroadcastSpec;
namespace v0 namespace v0 {
{
class Result; class Result;
} }
} // namespace op } // namespace op
namespace pattern namespace pattern {
{
class Matcher; class Matcher;
} }
@ -82,13 +77,11 @@ namespace ngraph
NGRAPH_API NGRAPH_API
std::string node_validation_failure_loc_string(const Node* node); std::string node_validation_failure_loc_string(const Node* node);
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, size_t i);
size_t i);
NGRAPH_API NGRAPH_API
const NodeVector& check_single_output_args(const NodeVector& args); const NodeVector& check_single_output_args(const NodeVector& args);
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, size_t i);
size_t i);
NGRAPH_API NGRAPH_API
OutputVector as_output_vector(const NodeVector& args); OutputVector as_output_vector(const NodeVector& args);
@ -118,13 +111,13 @@ namespace ngraph
/// a runtime error. /// a runtime error.
#define TYPE_CASE(a) \ #define TYPE_CASE(a) \
case element::Type_t::a: rc = evaluate<element::Type_t::a> case element::Type_t::a: \
rc = evaluate<element::Type_t::a>
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// or a (possibly empty) tuple of values. /// or a (possibly empty) tuple of values.
class NGRAPH_API Node : public std::enable_shared_from_this<Node> class NGRAPH_API Node : public std::enable_shared_from_this<Node> {
{
// For access to m_outputs. // For access to m_outputs.
friend class descriptor::Input; friend class descriptor::Input;
@ -196,7 +189,9 @@ namespace ngraph
public: public:
virtual ~Node(); virtual ~Node();
virtual bool visit_attributes(AttributeVisitor&) { return false; } virtual bool visit_attributes(AttributeVisitor&) {
return false;
}
/// \returns the autobroadcasr spec /// \returns the autobroadcasr spec
virtual const op::AutoBroadcastSpec& get_autob() const; virtual const op::AutoBroadcastSpec& get_autob() const;
@ -208,8 +203,7 @@ namespace ngraph
/// \param output_values Tensors for the outputs to compute. One for each result /// \param output_values Tensors for the outputs to compute. One for each result
/// \param input_values Tensors for the inputs. One for each inputs. /// \param input_values Tensors for the inputs. One for each inputs.
/// \returns true if successful /// \returns true if successful
virtual bool evaluate(const HostTensorVector& output_values, virtual bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const;
const HostTensorVector& input_values) const;
/// \brief Evaluates the op on input_values putting results in output_values /// \brief Evaluates the op on input_values putting results in output_values
/// \param output_values Tensors for the outputs to compute. One for each result /// \param output_values Tensors for the outputs to compute. One for each result
/// \param input_values Tensors for the inputs. One for each inputs. /// \param input_values Tensors for the inputs. One for each inputs.
@ -227,12 +221,16 @@ namespace ngraph
/// ///
/// \return A vector of nodes comprising the sub-graph. The order of output /// \return A vector of nodes comprising the sub-graph. The order of output
/// tensors must match the match output tensors of the FusedOp /// tensors must match the match output tensors of the FusedOp
virtual OutputVector decompose_op() const { return OutputVector(); } virtual OutputVector decompose_op() const {
return OutputVector();
}
/// Returns the NodeTypeInfo for the node's class. /// Returns the NodeTypeInfo for the node's class.
/// During transition to type_info, returns a dummy type_info for Node if the class /// During transition to type_info, returns a dummy type_info for Node if the class
/// has not been updated yet. /// has not been updated yet.
virtual const type_info_t& get_type_info() const = 0; virtual const type_info_t& get_type_info() const = 0;
const char* get_type_name() const { return get_type_info().name; } const char* get_type_name() const {
return get_type_info().name;
}
/// Sets/replaces the arguments with new arguments. /// Sets/replaces the arguments with new arguments.
void set_arguments(const NodeVector& arguments); void set_arguments(const NodeVector& arguments);
/// Sets/replaces the arguments with new arguments. /// Sets/replaces the arguments with new arguments.
@ -240,16 +238,13 @@ namespace ngraph
/// Sets/replaces the arguments with new arguments. /// Sets/replaces the arguments with new arguments.
void set_argument(size_t position, const Output<Node>& argument); void set_argument(size_t position, const Output<Node>& argument);
void set_output_type(size_t i, void set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape);
const element::Type& element_type,
const PartialShape& pshape);
/// Sets the number of outputs /// Sets the number of outputs
void set_output_size(size_t output_size); void set_output_size(size_t output_size);
void invalidate_values(); void invalidate_values();
virtual void revalidate_and_infer_types() virtual void revalidate_and_infer_types() {
{
invalidate_values(); invalidate_values();
validate_and_infer_types(); validate_and_infer_types();
} }
@ -273,7 +268,9 @@ namespace ngraph
const std::string& get_friendly_name() const; const std::string& get_friendly_name() const;
virtual bool is_dynamic() const; virtual bool is_dynamic() const;
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const {
return m_instance_id;
}
/// \brief Writes a description of a node to a stream /// \brief Writes a description of a node to a stream
/// \param os The stream; should be returned /// \param os The stream; should be returned
/// \param depth How many levels of inputs to describe /// \param depth How many levels of inputs to describe
@ -346,8 +343,7 @@ namespace ngraph
descriptor::Tensor& get_input_tensor(size_t i) const; descriptor::Tensor& get_input_tensor(size_t i) const;
/// Returns the tensor name for output i /// Returns the tensor name for output i
NGRAPH_DEPRECATED( NGRAPH_DEPRECATED("The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
"The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
const std::string& get_output_tensor_name(size_t i) const; const std::string& get_output_tensor_name(size_t i) const;
std::set<Input<Node>> get_output_target_inputs(size_t i) const; std::set<Input<Node>> get_output_target_inputs(size_t i) const;
@ -368,8 +364,7 @@ namespace ngraph
const PartialShape& get_input_partial_shape(size_t i) const; const PartialShape& get_input_partial_shape(size_t i) const;
/// Returns the tensor name for input i /// Returns the tensor name for input i
NGRAPH_DEPRECATED( NGRAPH_DEPRECATED("The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
"The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
const std::string& get_input_tensor_name(size_t i) const; const std::string& get_input_tensor_name(size_t i) const;
std::unordered_set<descriptor::Tensor*> liveness_new_list; std::unordered_set<descriptor::Tensor*> liveness_new_list;
@ -384,8 +379,7 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const; std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
std::shared_ptr<Node> copy_with_new_inputs( std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& inputs,
const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const; const std::vector<std::shared_ptr<Node>>& control_dependencies) const;
/// True if this and node have one output with same element type and shape /// True if this and node have one output with same element type and shape
@ -393,21 +387,22 @@ namespace ngraph
using RTMap = std::map<std::string, std::shared_ptr<Variant>>; using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& get_rt_info() { return m_rt_info; } RTMap& get_rt_info() {
const RTMap& get_rt_info() const { return m_rt_info; } return m_rt_info;
}
const RTMap& get_rt_info() const {
return m_rt_info;
}
const std::unordered_set<std::string>& get_provenance_tags() const; const std::unordered_set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag); void add_provenance_tag(const std::string& tag);
template <typename T> template <typename T>
void add_provenance_tags(T tag_set) void add_provenance_tags(T tag_set) {
{ for (auto tag : tag_set) {
for (auto tag : tag_set)
{
add_provenance_tag(tag); add_provenance_tag(tag);
} }
} }
/// \brief Adds tag_set to this node and all intermediate nodes above base /// \brief Adds tag_set to this node and all intermediate nodes above base
void add_provenance_tags_above(const OutputVector& base, void add_provenance_tags_above(const OutputVector& base, const std::unordered_set<std::string>& tag_set);
const std::unordered_set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag); void remove_provenance_tag(const std::string& tag);
/// \brief Add node to additional nodes that receive tags /// \brief Add node to additional nodes that receive tags
void add_provenance_group_member(const std::shared_ptr<Node>& node); void add_provenance_group_member(const std::shared_ptr<Node>& node);
@ -433,12 +428,18 @@ namespace ngraph
NodeVector get_users(bool check_is_used = false) const; NodeVector get_users(bool check_is_used = false) const;
/// \return Version of this node /// \return Version of this node
virtual size_t get_version() const { return get_type_info().version; } virtual size_t get_version() const {
return get_type_info().version;
}
NGRAPH_DEPRECATED("This method is deprecated and will be removed soon.") NGRAPH_DEPRECATED("This method is deprecated and will be removed soon.")
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; } virtual std::shared_ptr<Node> get_default_value() const {
return nullptr;
}
/// Use instance ids for comparison instead of memory addresses to improve determinism /// Use instance ids for comparison instead of memory addresses to improve determinism
bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; } bool operator<(const Node& other) const {
return m_instance_id < other.m_instance_id;
}
/// \return A vector containing a handle for each of this node's inputs, in order. /// \return A vector containing a handle for each of this node's inputs, in order.
// TODO: Rename to get_inputs()? // TODO: Rename to get_inputs()?
std::vector<Input<Node>> inputs(); std::vector<Input<Node>> inputs();
@ -474,12 +475,10 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs. /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
Output<const Node> output(size_t output_index) const; Output<const Node> output(size_t output_index) const;
void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations) void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations) {
{
m_op_annotations = op_annotations; m_op_annotations = op_annotations;
} }
std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const {
{
return m_op_annotations; return m_op_annotations;
} }
@ -558,23 +557,21 @@ namespace ngraph
static const ::ngraph::Node::type_info_t& get_type_info_static() static const ::ngraph::Node::type_info_t& get_type_info_static()
#define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS) \ #define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS) \
const ::ngraph::Node::type_info_t& CLASS::get_type_info() const \ const ::ngraph::Node::type_info_t& CLASS::get_type_info() const { \
{ \
return get_type_info_static(); \ return get_type_info_static(); \
} \ } \
const ::ngraph::Node::type_info_t CLASS::type_info = CLASS::get_type_info_static() const ::ngraph::Node::type_info_t CLASS::type_info = CLASS::get_type_info_static()
#define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX, PARENT_CLASS) \ #define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX, PARENT_CLASS) \
const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() \ const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() { \
{ \ static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, \
static const ::ngraph::Node::type_info_t type_info_static{ \ _VERSION_INDEX, \
TYPE_NAME, _VERSION_INDEX, &PARENT_CLASS::get_type_info_static()}; \ &PARENT_CLASS::get_type_info_static()}; \
return type_info_static; \ return type_info_static; \
} \ } \
_NGRAPH_RTTI_DEFINITION_COMMON(CLASS) _NGRAPH_RTTI_DEFINITION_COMMON(CLASS)
#define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX) \ #define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX) \
const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() \ const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() { \
{ \
static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, _VERSION_INDEX}; \ static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, _VERSION_INDEX}; \
return type_info_static; \ return type_info_static; \
} \ } \
@ -610,23 +607,14 @@ namespace ngraph
/// For convenience, TYPE_NAME and CLASS name are recommended to be the same. /// For convenience, TYPE_NAME and CLASS name are recommended to be the same.
/// ///
#define NGRAPH_RTTI_DEFINITION(...) \ #define NGRAPH_RTTI_DEFINITION(...) \
_NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR( \ _NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR(__VA_ARGS__, \
__VA_ARGS__, _NGRAPH_RTTI_DEFINITION_WITH_PARENT, _NGRAPH_RTTI_DEFINITION_NO_PARENT)( \ _NGRAPH_RTTI_DEFINITION_WITH_PARENT, \
__VA_ARGS__)) _NGRAPH_RTTI_DEFINITION_NO_PARENT)(__VA_ARGS__))
// Like an Output but with a Node* instead of a shared_ptr<Node> // Like an Output but with a Node* instead of a shared_ptr<Node>
struct RawNodeOutput struct RawNodeOutput {
{ RawNodeOutput(const Output<Node>& value) : node(value.get_node()), index(value.get_index()) {}
RawNodeOutput(const Output<Node>& value) RawNodeOutput(Node* node, size_t index) : node(node), index(index) {}
: node(value.get_node())
, index(value.get_index())
{
}
RawNodeOutput(Node* node, size_t index)
: node(node)
, index(index)
{
}
RawNodeOutput(const RawNodeOutput&) = default; RawNodeOutput(const RawNodeOutput&) = default;
RawNodeOutput() = default; RawNodeOutput() = default;
RawNodeOutput& operator=(const RawNodeOutput&) = default; RawNodeOutput& operator=(const RawNodeOutput&) = default;
@ -634,49 +622,56 @@ namespace ngraph
Node* node; Node* node;
size_t index{0}; size_t index{0};
operator Output<Node>() { return Output<Node>(node->shared_from_this(), index); } operator Output<Node>() {
bool operator==(const RawNodeOutput& other) const return Output<Node>(node->shared_from_this(), index);
{ }
bool operator==(const RawNodeOutput& other) const {
return node == other.node && index == other.index; return node == other.node && index == other.index;
} }
bool operator!=(const RawNodeOutput& other) const { return !(*this == other); } bool operator!=(const RawNodeOutput& other) const {
bool operator<(const RawNodeOutput& other) const return !(*this == other);
{ }
bool operator<(const RawNodeOutput& other) const {
return node < other.node || (node == other.node && index < other.index); return node < other.node || (node == other.node && index < other.index);
} }
bool operator>(const RawNodeOutput& other) const bool operator>(const RawNodeOutput& other) const {
{
return node > other.node || (node == other.node && index > other.index); return node > other.node || (node == other.node && index > other.index);
} }
bool operator<=(const RawNodeOutput& other) const { return !(*this > other); } bool operator<=(const RawNodeOutput& other) const {
bool operator>=(const RawNodeOutput& other) const { return !(*this < other); } return !(*this > other);
}
bool operator>=(const RawNodeOutput& other) const {
return !(*this < other);
}
}; };
/// \brief Visits a reference to a node that has been registered with the visitor. /// \brief Visits a reference to a node that has been registered with the visitor.
template <> template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<Node>> : public VisitorAdapter class NGRAPH_API AttributeAdapter<std::shared_ptr<Node>> : public VisitorAdapter {
{
public: public:
AttributeAdapter(std::shared_ptr<Node>& value); AttributeAdapter(std::shared_ptr<Node>& value);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Node>>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Node>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
protected: protected:
std::shared_ptr<Node>& m_ref; std::shared_ptr<Node>& m_ref;
}; };
template <> template <>
class NGRAPH_API AttributeAdapter<NodeVector> : public VisitorAdapter class NGRAPH_API AttributeAdapter<NodeVector> : public VisitorAdapter {
{
public: public:
AttributeAdapter(NodeVector& ref); AttributeAdapter(NodeVector& ref);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<NodeVector>", 0}; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<NodeVector>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
protected: protected:
NodeVector& m_ref; NodeVector& m_ref;
@ -684,25 +679,17 @@ namespace ngraph
using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>; using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
class NGRAPH_API NodeValidationFailure : public CheckFailure class NGRAPH_API NodeValidationFailure : public CheckFailure {
{
public: public:
NodeValidationFailure(const CheckLocInfo& check_loc_info, NodeValidationFailure(const CheckLocInfo& check_loc_info, const Node* node, const std::string& explanation)
const Node* node, : CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation) {}
const std::string& explanation)
: CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation)
{
}
}; };
} // namespace ngraph } // namespace ngraph
#define NODE_VALIDATION_CHECK(node, ...) \ #define NODE_VALIDATION_CHECK(node, ...) NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
namespace ngraph namespace ngraph {
{
template <typename T> template <typename T>
void check_new_args_count(const Node* node, T new_args) void check_new_args_count(const Node* node, T new_args) {
{
NODE_VALIDATION_CHECK(node, NODE_VALIDATION_CHECK(node,
new_args.size() == node->input_values().size(), new_args.size() == node->input_values().size(),
"clone_with_new_inputs() expected ", "clone_with_new_inputs() expected ",

View File

@ -13,22 +13,18 @@
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/variant.hpp" #include "ngraph/variant.hpp"
namespace ngraph namespace ngraph {
{
class Node; class Node;
template <typename NodeType> template <typename NodeType>
class Output; class Output;
template <typename NodeType> template <typename NodeType>
class Input class Input {};
{
};
/// \brief A handle for one of a node's inputs. /// \brief A handle for one of a node's inputs.
template <> template <>
class NGRAPH_API Input<Node> class NGRAPH_API Input<Node> {
{
public: public:
/// \brief Constructs a Input. /// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle. /// \param node Pointer to the node for the input handle.
@ -79,8 +75,7 @@ namespace ngraph
/// \brief A handle for one of a node's inputs. /// \brief A handle for one of a node's inputs.
template <> template <>
class NGRAPH_API Input<const Node> class NGRAPH_API Input<const Node> {
{
public: public:
/// \brief Constructs a Input. /// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle. /// \param node Pointer to the node for the input handle.

View File

@ -14,22 +14,18 @@
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/variant.hpp" #include "ngraph/variant.hpp"
namespace ngraph namespace ngraph {
{
class Node; class Node;
template <typename NodeType> template <typename NodeType>
class Input; class Input;
template <typename NodeType> template <typename NodeType>
class Output class Output {};
{
};
/// \brief A handle for one of a node's outputs. /// \brief A handle for one of a node's outputs.
template <> template <>
class NGRAPH_API Output<Node> class NGRAPH_API Output<Node> {
{
public: public:
/// \brief Constructs a Output. /// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle. /// \param node A pointer to the node for the output handle.
@ -46,10 +42,7 @@ namespace ngraph
/// \brief Constructs a Output, referencing the zeroth output of the node. /// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle. /// \param node A `shared_ptr` to the node for the output handle.
template <typename T> template <typename T>
Output(const std::shared_ptr<T>& node) Output(const std::shared_ptr<T>& node) : Output(node ? node->get_default_output() : Output<Node>()) {}
: Output(node ? node->get_default_output() : Output<Node>())
{
}
/// A null output /// A null output
Output() = default; Output() = default;
@ -109,8 +102,7 @@ namespace ngraph
}; };
template <> template <>
class NGRAPH_API Output<const Node> class NGRAPH_API Output<const Node> {
{
public: public:
/// \brief Constructs a Output. /// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle. /// \param node A pointer to the node for the output handle.
@ -127,10 +119,7 @@ namespace ngraph
/// \brief Constructs a Output, referencing the zeroth output of the node. /// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle. /// \param node A `shared_ptr` to the node for the output handle.
template <typename T> template <typename T>
Output(const std::shared_ptr<T>& node) Output(const std::shared_ptr<T>& node) : Output(node ? node->get_default_output() : Output<const Node>()) {}
: Output(node ? node->get_default_output() : Output<const Node>())
{
}
/// A null output /// A null output
Output() = default; Output() = default;

View File

@ -8,22 +8,22 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise absolute value operation. /// \brief Elementwise absolute value operation.
/// ///
class NGRAPH_API Abs : public util::UnaryElementwiseArithmetic class NGRAPH_API Abs : public util::UnaryElementwiseArithmetic {
{
public: public:
static constexpr NodeTypeInfo type_info{"Abs", 0}; static constexpr NodeTypeInfo type_info{"Abs", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
Abs() = default; Abs() = default;
bool visit_attributes(AttributeVisitor&) override { return true; } bool visit_attributes(AttributeVisitor&) override {
return true;
}
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
/// ///
/// \param arg Output that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
@ -33,11 +33,9 @@ namespace ngraph
/// ///
Abs(const Output<Node>& arg); Abs(const Output<Node>& arg);
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -8,19 +8,17 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise inverse cosine (arccos) operation. /// \brief Elementwise inverse cosine (arccos) operation.
/// ///
class NGRAPH_API Acos : public util::UnaryElementwiseArithmetic class NGRAPH_API Acos : public util::UnaryElementwiseArithmetic {
{
public: public:
static constexpr NodeTypeInfo type_info{"Acos", 0}; static constexpr NodeTypeInfo type_info{"Acos", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs an arccos operation. /// \brief Constructs an arccos operation.
Acos() = default; Acos() = default;
/// \brief Constructs an arccos operation. /// \brief Constructs an arccos operation.
@ -31,11 +29,11 @@ namespace ngraph
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Acos(const Output<Node>& arg); Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor&) override { return true; } bool visit_attributes(AttributeVisitor&) override {
std::shared_ptr<Node> return true;
clone_with_new_inputs(const OutputVector& new_args) const override; }
bool evaluate(const HostTensorVector& outputs, std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -8,16 +8,12 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Elementwise inverse hyperbolic cos operation. /// \brief Elementwise inverse hyperbolic cos operation.
/// ///
class NGRAPH_API Acosh : public util::UnaryElementwiseArithmetic class NGRAPH_API Acosh : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -32,11 +28,11 @@ namespace ngraph
/// ///
Acosh(const Output<Node>& arg); Acosh(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool visit_attributes(AttributeVisitor&) override {
bool visit_attributes(AttributeVisitor&) override { return true; } return true;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v3 } // namespace v3

View File

@ -7,16 +7,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v8 {
{
namespace v8
{
/// \brief Adaptive average pooling operation. /// \brief Adaptive average pooling operation.
/// ///
class NGRAPH_API AdaptiveAvgPool : public Op class NGRAPH_API AdaptiveAvgPool : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -35,8 +31,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v8 } // namespace v8
} // namespace op } // namespace op

View File

@ -7,16 +7,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v8 {
{
namespace v8
{
/// \brief Adaptive max pooling operation. /// \brief Adaptive max pooling operation.
/// ///
class NGRAPH_API AdaptiveMaxPool : public Op class NGRAPH_API AdaptiveMaxPool : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -33,18 +29,18 @@ namespace ngraph
/// \param index_element_type Specifies the output tensor type for indices /// \param index_element_type Specifies the output tensor type for indices
/// output /// output
/// ///
AdaptiveMaxPool( AdaptiveMaxPool(const Output<Node>& data,
const Output<Node>& data,
const Output<Node>& output_shape, const Output<Node>& output_shape,
const ngraph::element::Type& index_element_type = ngraph::element::i64); const ngraph::element::Type& index_element_type = ngraph::element::i64);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
element::Type get_index_element_type() const { return m_index_element_type; } element::Type get_index_element_type() const {
return m_index_element_type;
}
protected: protected:
ngraph::element::Type m_index_element_type = ngraph::element::i64; ngraph::element::Type m_index_element_type = ngraph::element::i64;

View File

@ -8,24 +8,17 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Elementwise addition operation. /// \brief Elementwise addition operation.
/// ///
class NGRAPH_API Add : public util::BinaryElementwiseArithmetic class NGRAPH_API Add : public util::BinaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an uninitialized addition operation /// \brief Constructs an uninitialized addition operation
Add() Add() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) {}
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an addition operation. /// \brief Constructs an addition operation.
/// ///
@ -40,16 +33,13 @@ namespace ngraph
/// ///
Add(const Output<Node>& arg0, Add(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v1 } // namespace v1

View File

@ -8,16 +8,12 @@
#include "ngraph/op/util/binary_elementwise_logical.hpp" #include "ngraph/op/util/binary_elementwise_logical.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Elementwise logical-and operation. /// \brief Elementwise logical-and operation.
/// ///
class NGRAPH_API LogicalAnd : public util::BinaryElementwiseLogical class NGRAPH_API LogicalAnd : public util::BinaryElementwiseLogical {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a logical-and operation. /// \brief Constructs a logical-and operation.
@ -35,14 +31,11 @@ namespace ngraph
/// ///
LogicalAnd(const Output<Node>& arg0, LogicalAnd(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v1 } // namespace v1

View File

@ -8,19 +8,17 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise inverse sine (arcsin) operation. /// \brief Elementwise inverse sine (arcsin) operation.
/// ///
class NGRAPH_API Asin : public util::UnaryElementwiseArithmetic class NGRAPH_API Asin : public util::UnaryElementwiseArithmetic {
{
public: public:
static constexpr NodeTypeInfo type_info{"Asin", 0}; static constexpr NodeTypeInfo type_info{"Asin", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs an arcsin operation. /// \brief Constructs an arcsin operation.
Asin() = default; Asin() = default;
/// \brief Constructs an arcsin operation. /// \brief Constructs an arcsin operation.
@ -32,11 +30,11 @@ namespace ngraph
/// ///
Asin(const Output<Node>& arg); Asin(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool visit_attributes(AttributeVisitor&) override {
bool visit_attributes(AttributeVisitor&) override { return true; } return true;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -8,16 +8,12 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Elementwise inverse hyperbolic sin operation. /// \brief Elementwise inverse hyperbolic sin operation.
/// ///
class NGRAPH_API Asinh : public util::UnaryElementwiseArithmetic class NGRAPH_API Asinh : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -32,11 +28,11 @@ namespace ngraph
/// ///
Asinh(const Output<Node>& arg); Asinh(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool visit_attributes(AttributeVisitor&) override {
bool visit_attributes(AttributeVisitor&) override { return true; } return true;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v3 } // namespace v3

View File

@ -8,27 +8,19 @@
#include "ngraph/op/util/variable.hpp" #include "ngraph/op/util/variable.hpp"
#include "ngraph/op/util/variable_extension.hpp" #include "ngraph/op/util/variable_extension.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op class NGRAPH_API AssignBase : public Sink, public VariableExtension {
{
class NGRAPH_API AssignBase : public Sink, public VariableExtension
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
AssignBase() = default; AssignBase() = default;
/// \brief Constructs an AssignBase operation. /// \brief Constructs an AssignBase operation.
explicit AssignBase(const OutputVector& arguments) explicit AssignBase(const OutputVector& arguments) : Sink(arguments) {}
: Sink(arguments)
{
}
}; };
namespace v3 namespace v3 {
{
/// \brief Assign operation sets an input value to the variable with `variable_id` /// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public AssignBase class NGRAPH_API Assign : public AssignBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
Assign() = default; Assign() = default;
@ -40,10 +32,11 @@ namespace ngraph
Assign(const Output<Node>& new_value, const std::string& variable_id); Assign(const Output<Node>& new_value, const std::string& variable_id);
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::string get_variable_id() const override { return m_variable_id; } std::string get_variable_id() const override {
return m_variable_id;
}
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
@ -51,11 +44,9 @@ namespace ngraph
std::string m_variable_id; std::string m_variable_id;
}; };
} // namespace v3 } // namespace v3
namespace v6 namespace v6 {
{
/// \brief Assign operation sets an input value to the variable with `variable_id` /// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public AssignBase class NGRAPH_API Assign : public AssignBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
Assign() = default; Assign() = default;
@ -70,23 +61,19 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::string get_variable_id() const override std::string get_variable_id() const override {
{ NGRAPH_CHECK(m_variable, "Variable is not initialized. Variable_id is unavailable");
NGRAPH_CHECK(m_variable,
"Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id; return m_variable->get_info().variable_id;
} }
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs, const HostTensorVector& inputs,
const EvaluationContext& evaluation_context) const override; const EvaluationContext& evaluation_context) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
bool constant_fold(OutputVector& output_values, bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
const OutputVector& inputs_values) override;
}; };
} // namespace v6 } // namespace v6
} // namespace op } // namespace op

View File

@ -8,16 +8,12 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise inverse tangent (arctan) operation. /// \brief Elementwise inverse tangent (arctan) operation.
/// ///
class NGRAPH_API Atan : public util::UnaryElementwiseArithmetic class NGRAPH_API Atan : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an arctan operation. /// \brief Constructs an arctan operation.
@ -32,11 +28,11 @@ namespace ngraph
/// ///
Atan(const Output<Node>& arg); Atan(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool visit_attributes(AttributeVisitor&) override {
bool visit_attributes(AttributeVisitor&) override { return true; } return true;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -8,16 +8,12 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Elementwise inverse hyperbolic tangent operation. /// \brief Elementwise inverse hyperbolic tangent operation.
/// ///
class NGRAPH_API Atanh : public util::UnaryElementwiseArithmetic class NGRAPH_API Atanh : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -32,11 +28,11 @@ namespace ngraph
/// ///
Atanh(const Output<Node>& arg); Atanh(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool visit_attributes(AttributeVisitor&) override {
bool visit_attributes(AttributeVisitor&) override { return true; } return true;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v3 } // namespace v3

View File

@ -7,16 +7,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Batched average pooling operation. /// \brief Batched average pooling operation.
/// ///
class NGRAPH_API AvgPool : public Op class NGRAPH_API AvgPool : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -52,8 +48,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The kernel shape. /// \return The kernel shape.
const Shape& get_kernel() const; const Shape& get_kernel() const;

View File

@ -10,14 +10,10 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{ class NGRAPH_API BatchNormInference : public Op {
namespace v0
{
class NGRAPH_API BatchNormInference : public Op
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default; BatchNormInference() = default;
@ -38,10 +34,13 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; } double get_eps_value() const {
void set_eps_value(double epsilon) { m_epsilon = epsilon; } return m_epsilon;
std::shared_ptr<Node> }
clone_with_new_inputs(const OutputVector& new_args) const override; void set_eps_value(double epsilon) {
m_epsilon = epsilon;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
private: private:
static constexpr size_t INPUT_GAMMA = 0; static constexpr size_t INPUT_GAMMA = 0;
@ -53,10 +52,8 @@ namespace ngraph
double m_epsilon; double m_epsilon;
}; };
} // namespace v0 } // namespace v0
namespace v5 namespace v5 {
{ class NGRAPH_API BatchNormInference : public Op {
class NGRAPH_API BatchNormInference : public Op
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default; BatchNormInference() = default;
@ -77,10 +74,13 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; } double get_eps_value() const {
void set_eps_value(double epsilon) { m_epsilon = epsilon; } return m_epsilon;
std::shared_ptr<Node> }
clone_with_new_inputs(const OutputVector& new_args) const override; void set_eps_value(double epsilon) {
m_epsilon = epsilon;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
private: private:
static constexpr size_t INPUT_DATA = 0; static constexpr size_t INPUT_DATA = 0;

View File

@ -7,12 +7,9 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief BatchToSpace permutes data from the batch dimension of the data tensor into /// \brief BatchToSpace permutes data from the batch dimension of the data tensor into
/// spatial dimensions. /// spatial dimensions.
/// ///
@ -24,8 +21,7 @@ namespace ngraph
/// D_2 * block_shape[2] - crops_begin[2] - crops_end[2], ..., /// D_2 * block_shape[2] - crops_begin[2] - crops_end[2], ...,
/// D_{N - 1} * block_shape[N - 1] - crops_begin[N - 1] - crops_end[N - 1]` /// D_{N - 1} * block_shape[N - 1] - crops_begin[N - 1] - crops_end[N - 1]`
/// of the same type as `data` input. /// of the same type as `data` input.
class NGRAPH_API BatchToSpace : public Op class NGRAPH_API BatchToSpace : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
BatchToSpace() = default; BatchToSpace() = default;
@ -41,13 +37,11 @@ namespace ngraph
const Output<Node>& block_shape, const Output<Node>& block_shape,
const Output<Node>& crops_begin, const Output<Node>& crops_begin,
const Output<Node>& crops_end); const Output<Node>& crops_end);
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
}; };
} // namespace v1 } // namespace v1

View File

@ -8,19 +8,14 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{ class NGRAPH_API BinaryConvolution : public Op {
namespace v1
{
class NGRAPH_API BinaryConvolution : public Op
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
enum class BinaryConvolutionMode enum class BinaryConvolutionMode {
{
// Interpret input data and kernel values: 0 as -1, 1 as 1 // Interpret input data and kernel values: 0 as -1, 1 as 1
XNOR_POPCOUNT XNOR_POPCOUNT
}; };
@ -63,30 +58,57 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The strides. /// \return The strides.
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const {
void set_strides(const Strides& strides) { m_strides = strides; } return m_strides;
}
void set_strides(const Strides& strides) {
m_strides = strides;
}
/// \return The dilations. /// \return The dilations.
const Strides& get_dilations() const { return m_dilations; } const Strides& get_dilations() const {
void set_dilations(const Strides& dilations) { m_dilations = dilations; } return m_dilations;
}
void set_dilations(const Strides& dilations) {
m_dilations = dilations;
}
/// \return The padding-below sizes (possibly negative). /// \return The padding-below sizes (possibly negative).
const CoordinateDiff& get_pads_begin() const { return m_pads_begin; } const CoordinateDiff& get_pads_begin() const {
void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; } return m_pads_begin;
}
void set_pads_begin(const CoordinateDiff& pads_begin) {
m_pads_begin = pads_begin;
}
/// \return The padding-above sizes (possibly negative). /// \return The padding-above sizes (possibly negative).
const CoordinateDiff& get_pads_end() const { return m_pads_end; } const CoordinateDiff& get_pads_end() const {
void set_adding_above(const CoordinateDiff& pads_end) { m_pads_end = pads_end; } return m_pads_end;
}
void set_adding_above(const CoordinateDiff& pads_end) {
m_pads_end = pads_end;
}
/// \return The pad type for convolution. /// \return The pad type for convolution.
const PadType& get_auto_pad() const { return m_auto_pad; } const PadType& get_auto_pad() const {
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; } return m_auto_pad;
}
void set_auto_pad(const PadType& auto_pad) {
m_auto_pad = auto_pad;
}
/// \return The mode of convolution. /// \return The mode of convolution.
const BinaryConvolutionMode& get_mode() const { return m_mode; } const BinaryConvolutionMode& get_mode() const {
void set_mode(const BinaryConvolutionMode& mode) { m_mode = mode; } return m_mode;
}
void set_mode(const BinaryConvolutionMode& mode) {
m_mode = mode;
}
/// \return The pad value. /// \return The pad value.
float get_pad_value() const { return m_pad_value; } float get_pad_value() const {
void set_pad_value(float pad_value) { m_pad_value = pad_value; } return m_pad_value;
}
void set_pad_value(float pad_value) {
m_pad_value = pad_value;
}
protected: protected:
BinaryConvolutionMode mode_from_string(const std::string& mode) const; BinaryConvolutionMode mode_from_string(const std::string& mode) const;
@ -102,22 +124,20 @@ namespace ngraph
} // namespace op } // namespace op
NGRAPH_API NGRAPH_API
std::ostream& operator<<(std::ostream& s, std::ostream& operator<<(std::ostream& s, const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
template <> template <>
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode> class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode> : public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode> {
{
public: public:
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value) AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value) : EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value) {}
{
}
static constexpr DiscreteTypeInfo type_info{ static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>",
"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>", 0}; 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; } const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
}; };
} // namespace ngraph } // namespace ngraph

View File

@ -9,16 +9,12 @@
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/broadcast_base.hpp" #include "ngraph/op/util/broadcast_base.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes. /// input as needed along the new axes.
class NGRAPH_API Broadcast : public util::BroadcastBase class NGRAPH_API Broadcast : public util::BroadcastBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -54,13 +50,13 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
// \return Broadcast Specification. // \return Broadcast Specification.
const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; } const BroadcastModeSpec& get_broadcast_spec() const {
void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec) return m_mode;
{ }
void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec) {
m_mode = broadcast_spec; m_mode = broadcast_spec;
} }
@ -68,22 +64,18 @@ namespace ngraph
/// \return true and the AxisSet if broadcast axes can be fully determined. /// \return true and the AxisSet if broadcast axes can be fully determined.
std::pair<bool, AxisSet> get_broadcast_axes() const override; std::pair<bool, AxisSet> get_broadcast_axes() const override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
private: private:
bool broadcast_evaluate(const HostTensorVector& outputs, bool broadcast_evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
const HostTensorVector& inputs) const;
}; };
} // namespace v3 } // namespace v3
namespace v1 namespace v1 {
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes. /// input as needed along the new axes.
class NGRAPH_API Broadcast : public util::BroadcastBase class NGRAPH_API Broadcast : public util::BroadcastBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -115,24 +107,22 @@ namespace ngraph
/// axes /// axes
Broadcast(const Output<Node>& arg, Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape, const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec = const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return Broadcast Specification. /// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; } const AutoBroadcastSpec& get_broadcast_spec() const {
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec) return m_broadcast_spec;
{ }
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec) {
m_broadcast_spec = broadcast_spec; m_broadcast_spec = broadcast_spec;
} }
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
protected: protected:

View File

@ -6,15 +6,11 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Operation that bucketizes the input based on boundaries /// \brief Operation that bucketizes the input based on boundaries
class NGRAPH_API Bucketize : public Op class NGRAPH_API Bucketize : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -34,17 +30,21 @@ namespace ngraph
virtual void validate_and_infer_types() override; virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
clone_with_new_inputs(const OutputVector& inputs) const override;
element::Type get_output_type() const { return m_output_type; } element::Type get_output_type() const {
void set_output_type(element::Type output_type) { m_output_type = output_type; } return m_output_type;
}
void set_output_type(element::Type output_type) {
m_output_type = output_type;
}
// Overload collision with method on Node // Overload collision with method on Node
using Node::set_output_type; using Node::set_output_type;
bool get_with_right_bound() const { return m_with_right_bound; } bool get_with_right_bound() const {
void set_with_right_bound(bool with_right_bound) return m_with_right_bound;
{ }
void set_with_right_bound(bool with_right_bound) {
m_with_right_bound = with_right_bound; m_with_right_bound = with_right_bound;
} }

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise ceiling operation. /// \brief Elementwise ceiling operation.
class NGRAPH_API Ceiling : public util::UnaryElementwiseArithmetic class NGRAPH_API Ceiling : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a ceiling operation. /// \brief Constructs a ceiling operation.
@ -24,11 +20,11 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Ceiling(const Output<Node>& arg); Ceiling(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor&) override { return true; } bool visit_attributes(AttributeVisitor&) override {
virtual std::shared_ptr<Node> return true;
clone_with_new_inputs(const OutputVector& new_args) const override; }
bool evaluate(const HostTensorVector& outputs, virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -7,19 +7,15 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Performs a clipping operation on all elements of the input node /// \brief Performs a clipping operation on all elements of the input node
/// ///
/// All input values that are outside of the <min;max> range are set to 'min' or 'max' /// All input values that are outside of the <min;max> range are set to 'min' or 'max'
/// depending on which side of the <min;max> range they are. The values that fall into /// depending on which side of the <min;max> range they are. The values that fall into
/// this range remain unchanged. /// this range remain unchanged.
class NGRAPH_API Clamp : public ngraph::op::Op class NGRAPH_API Clamp : public ngraph::op::Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -33,15 +29,17 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
double get_min() const { return m_min; } double get_min() const {
double get_max() const { return m_max; } return m_min;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; double get_max() const {
return m_max;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
private: private:

View File

@ -8,15 +8,11 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Concatenation operation. /// \brief Concatenation operation.
class NGRAPH_API Concat : public Op class NGRAPH_API Concat : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -37,20 +33,23 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The concatenation axis. /// \return The concatenation axis.
int64_t get_concatenation_axis() const { return m_concat_axis; } int64_t get_concatenation_axis() const {
void set_concatenation_axis(int64_t concatenation_axis) return m_concat_axis;
{ }
void set_concatenation_axis(int64_t concatenation_axis) {
m_concat_axis = concatenation_axis; m_concat_axis = concatenation_axis;
} }
/// \return The concatenation axis. /// \return The concatenation axis.
int64_t get_axis() const { return m_axis; } int64_t get_axis() const {
void set_axis(int64_t axis) { m_axis = axis; } return m_axis;
bool evaluate(const HostTensorVector& outputs, }
const HostTensorVector& inputs) const override; void set_axis(int64_t axis) {
m_axis = axis;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& output_values) const override; bool evaluate_lower(const HostTensorVector& output_values) const override;
bool evaluate_upper(const HostTensorVector& output_values) const override; bool evaluate_upper(const HostTensorVector& output_values) const override;

View File

@ -16,15 +16,11 @@
#include "ngraph/type/element_type_traits.hpp" #include "ngraph/type/element_type_traits.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Class for constants. /// \brief Class for constants.
class NGRAPH_API Constant : public Op class NGRAPH_API Constant : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -41,13 +37,8 @@ namespace ngraph
/// \param values A vector of literals for initializing the tensor constant. The /// \param values A vector of literals for initializing the tensor constant. The
/// size of values must match the size of the shape. /// size of values must match the size of the shape.
template <typename T> template <typename T>
Constant(const element::Type& type, Constant(const element::Type& type, const Shape& shape, const std::vector<T>& values) : Constant(type, shape) {
const Shape& shape, NODE_VALIDATION_CHECK(this,
const std::vector<T>& values)
: Constant(type, shape)
{
NODE_VALIDATION_CHECK(
this,
values.size() == 1 || values.size() == shape_size(m_shape), values.size() == 1 || values.size() == shape_size(m_shape),
"Did not get the expected number of literals for a constant of shape ", "Did not get the expected number of literals for a constant of shape ",
m_shape, m_shape,
@ -58,12 +49,9 @@ namespace ngraph
shape_size(m_shape), shape_size(m_shape),
")."); ").");
if (values.size() == 1) if (values.size() == 1) {
{
fill_data(type, values.front()); fill_data(type, values.front());
} } else {
else
{
write_values(values); write_values(values);
} }
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical(); m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
@ -77,44 +65,72 @@ namespace ngraph
/// \param shape The shape of the tensor constant. /// \param shape The shape of the tensor constant.
/// \param value A scalar for initializing the uniform tensor constant. The /// \param value A scalar for initializing the uniform tensor constant. The
/// value is broadcast to the specified shape. /// value is broadcast to the specified shape.
template <class T, template <class T, class = typename std::enable_if<std::is_fundamental<T>::value>::type>
class = typename std::enable_if<std::is_fundamental<T>::value>::type> Constant(const element::Type& type, const Shape& shape, T value) : Constant(type, shape) {
Constant(const element::Type& type, const Shape& shape, T value)
: Constant(type, shape)
{
fill_data(type, value); fill_data(type, value);
m_all_elements_bitwise_identical = true; m_all_elements_bitwise_identical = true;
} }
template <typename T> template <typename T>
void fill_data(const element::Type& type, T value) void fill_data(const element::Type& type, T value) {
{
using Type_t = element::Type_t; using Type_t = element::Type_t;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic push # pragma GCC diagnostic push
# pragma GCC diagnostic error "-Wswitch" # pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum" # pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (type) switch (type) {
{ case Type_t::boolean:
case Type_t::boolean: fill_data<Type_t::boolean>(value); break; fill_data<Type_t::boolean>(value);
case Type_t::bf16: fill_data<Type_t::bf16>(value); break; break;
case Type_t::f16: fill_data<Type_t::f16>(value); break; case Type_t::bf16:
case Type_t::f32: fill_data<Type_t::f32>(value); break; fill_data<Type_t::bf16>(value);
case Type_t::f64: fill_data<Type_t::f64>(value); break; break;
case Type_t::i4: fill_data<Type_t::i4>(value); break; case Type_t::f16:
case Type_t::i8: fill_data<Type_t::i8>(value); break; fill_data<Type_t::f16>(value);
case Type_t::i16: fill_data<Type_t::i16>(value); break; break;
case Type_t::i32: fill_data<Type_t::i32>(value); break; case Type_t::f32:
case Type_t::i64: fill_data<Type_t::i64>(value); break; fill_data<Type_t::f32>(value);
case Type_t::u1: fill_data<Type_t::u1>(value); break; break;
case Type_t::u4: fill_data<Type_t::u4>(value); break; case Type_t::f64:
case Type_t::u8: fill_data<Type_t::u8>(value); break; fill_data<Type_t::f64>(value);
case Type_t::u16: fill_data<Type_t::u16>(value); break; break;
case Type_t::u32: fill_data<Type_t::u32>(value); break; case Type_t::i4:
case Type_t::u64: fill_data<Type_t::u64>(value); break; fill_data<Type_t::i4>(value);
break;
case Type_t::i8:
fill_data<Type_t::i8>(value);
break;
case Type_t::i16:
fill_data<Type_t::i16>(value);
break;
case Type_t::i32:
fill_data<Type_t::i32>(value);
break;
case Type_t::i64:
fill_data<Type_t::i64>(value);
break;
case Type_t::u1:
fill_data<Type_t::u1>(value);
break;
case Type_t::u4:
fill_data<Type_t::u4>(value);
break;
case Type_t::u8:
fill_data<Type_t::u8>(value);
break;
case Type_t::u16:
fill_data<Type_t::u16>(value);
break;
case Type_t::u32:
fill_data<Type_t::u32>(value);
break;
case Type_t::u64:
fill_data<Type_t::u64>(value);
break;
case Type_t::undefined: case Type_t::undefined:
case Type_t::dynamic: throw std::runtime_error("unsupported type"); case Type_t::dynamic:
throw std::runtime_error("unsupported type");
} }
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic pop # pragma GCC diagnostic pop
@ -127,9 +143,7 @@ namespace ngraph
/// \param type The element type of the tensor constant. /// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant. /// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data. /// \param values A list of string values to use as the constant data.
Constant(const element::Type& type, Constant(const element::Type& type, const Shape& shape, const std::vector<std::string>& values);
const Shape& shape,
const std::vector<std::string>& values);
/// \brief Constructs a tensor constant with the supplied data /// \brief Constructs a tensor constant with the supplied data
/// ///
@ -144,12 +158,9 @@ namespace ngraph
/// \param shape The shape of the tensor constant. /// \param shape The shape of the tensor constant.
/// \param data A pointer to pre-allocated shared data. /// \param data A pointer to pre-allocated shared data.
template <typename T> template <typename T>
Constant(const element::Type& type, Constant(const element::Type& type, const Shape& shape, std::shared_ptr<runtime::SharedBuffer<T>> data)
const Shape& shape, : m_element_type(type),
std::shared_ptr<runtime::SharedBuffer<T>> data) m_shape(shape) {
: m_element_type(type)
, m_shape(shape)
{
m_data = data; m_data = data;
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
@ -160,23 +171,20 @@ namespace ngraph
virtual ~Constant() override; virtual ~Constant() override;
void validate_and_infer_types() override void validate_and_infer_types() override {
{
infer_element_type(); infer_element_type();
set_output_type(0, m_element_type, m_shape); set_output_type(0, m_element_type, m_shape);
} }
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& outputs) const override; bool evaluate_lower(const HostTensorVector& outputs) const override;
bool evaluate_upper(const HostTensorVector& outputs) const override; bool evaluate_upper(const HostTensorVector& outputs) const override;
// Don't constant fold a constant; it would make a copy // Don't constant fold a constant; it would make a copy
bool constant_fold(OutputVector& outputs, const OutputVector& inputs) override bool constant_fold(OutputVector& outputs, const OutputVector& inputs) override {
{
(void)outputs; (void)outputs;
(void)inputs; (void)inputs;
return false; return false;
@ -227,8 +235,7 @@ namespace ngraph
template <typename T> template <typename T>
static std::shared_ptr<Constant> create(const element::Type& type, static std::shared_ptr<Constant> create(const element::Type& type,
const Shape& shape, const Shape& shape,
const std::vector<T>& values) const std::vector<T>& values) {
{
return std::make_shared<Constant>(type, shape, values); return std::make_shared<Constant>(type, shape, values);
} }
@ -240,8 +247,7 @@ namespace ngraph
template <typename T> template <typename T>
static std::shared_ptr<Constant> create(const element::Type& type, static std::shared_ptr<Constant> create(const element::Type& type,
const Shape& shape, const Shape& shape,
std::initializer_list<T> values) std::initializer_list<T> values) {
{
return std::make_shared<Constant>(type, shape, std::vector<T>{values}); return std::make_shared<Constant>(type, shape, std::vector<T>{values});
} }
@ -250,21 +256,17 @@ namespace ngraph
/// \param type The element type of the tensor constant. /// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant. /// \param shape The shape of the tensor constant.
/// \param memory An continues memory chunk which contains the constant data. /// \param memory An continues memory chunk which contains the constant data.
static std::shared_ptr<Constant> static std::shared_ptr<Constant> create(const element::Type& type, const Shape& shape, const void* memory) {
create(const element::Type& type, const Shape& shape, const void* memory)
{
return std::make_shared<Constant>(type, shape, memory); return std::make_shared<Constant>(type, shape, memory);
} }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The initialization literals for the tensor constant. /// \return The initialization literals for the tensor constant.
std::vector<std::string> get_value_strings() const; std::vector<std::string> get_value_strings() const;
template <typename T> template <typename T>
std::vector<T> get_vector() const std::vector<T> get_vector() const {
{
const T* p = get_data_ptr<T>(); const T* p = get_data_ptr<T>();
if (p == nullptr) if (p == nullptr)
throw std::runtime_error("Cannot create vector! Buffer is not allocated."); throw std::runtime_error("Cannot create vector! Buffer is not allocated.");
@ -276,8 +278,7 @@ namespace ngraph
/// \tparam T Type to which data vector's entries will be cast. /// \tparam T Type to which data vector's entries will be cast.
/// \return Constant's data vector. /// \return Constant's data vector.
template <typename T> template <typename T>
std::vector<T> cast_vector() const std::vector<T> cast_vector() const {
{
auto source_type = get_element_type(); auto source_type = get_element_type();
std::vector<T> rc; std::vector<T> rc;
using Type_t = element::Type_t; using Type_t = element::Type_t;
@ -285,25 +286,57 @@ namespace ngraph
# pragma warning(push) # pragma warning(push)
# pragma warning(disable : 4244) # pragma warning(disable : 4244)
#endif #endif
switch (source_type) switch (source_type) {
{ case Type_t::boolean:
case Type_t::boolean: cast_vector<Type_t::boolean>(rc); break; cast_vector<Type_t::boolean>(rc);
case Type_t::bf16: cast_vector<Type_t::bf16>(rc); break; break;
case Type_t::f16: cast_vector<Type_t::f16>(rc); break; case Type_t::bf16:
case Type_t::f32: cast_vector<Type_t::f32>(rc); break; cast_vector<Type_t::bf16>(rc);
case Type_t::f64: cast_vector<Type_t::f64>(rc); break; break;
case Type_t::i4: cast_vector<Type_t::i4>(rc); break; case Type_t::f16:
case Type_t::i8: cast_vector<Type_t::i8>(rc); break; cast_vector<Type_t::f16>(rc);
case Type_t::i16: cast_vector<Type_t::i16>(rc); break; break;
case Type_t::i32: cast_vector<Type_t::i32>(rc); break; case Type_t::f32:
case Type_t::i64: cast_vector<Type_t::i64>(rc); break; cast_vector<Type_t::f32>(rc);
case Type_t::u1: cast_vector<Type_t::u1>(rc); break; break;
case Type_t::u4: cast_vector<Type_t::u4>(rc); break; case Type_t::f64:
case Type_t::u8: cast_vector<Type_t::u8>(rc); break; cast_vector<Type_t::f64>(rc);
case Type_t::u16: cast_vector<Type_t::u16>(rc); break; break;
case Type_t::u32: cast_vector<Type_t::u32>(rc); break; case Type_t::i4:
case Type_t::u64: cast_vector<Type_t::u64>(rc); break; cast_vector<Type_t::i4>(rc);
default: throw std::runtime_error("unsupported type"); break;
case Type_t::i8:
cast_vector<Type_t::i8>(rc);
break;
case Type_t::i16:
cast_vector<Type_t::i16>(rc);
break;
case Type_t::i32:
cast_vector<Type_t::i32>(rc);
break;
case Type_t::i64:
cast_vector<Type_t::i64>(rc);
break;
case Type_t::u1:
cast_vector<Type_t::u1>(rc);
break;
case Type_t::u4:
cast_vector<Type_t::u4>(rc);
break;
case Type_t::u8:
cast_vector<Type_t::u8>(rc);
break;
case Type_t::u16:
cast_vector<Type_t::u16>(rc);
break;
case Type_t::u32:
cast_vector<Type_t::u32>(rc);
break;
case Type_t::u64:
cast_vector<Type_t::u64>(rc);
break;
default:
throw std::runtime_error("unsupported type");
} }
#if defined(_MSC_VER) #if defined(_MSC_VER)
# pragma warning(pop) # pragma warning(pop)
@ -311,12 +344,12 @@ namespace ngraph
return rc; return rc;
} }
const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); } const void* get_data_ptr() const {
return (m_data ? m_data->get_ptr() : nullptr);
}
template <typename T> template <typename T>
const T* get_data_ptr() const const T* get_data_ptr() const {
{ if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0) {
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
{
throw ngraph_error("Buffer over-read"); throw ngraph_error("Buffer over-read");
} }
@ -324,16 +357,12 @@ namespace ngraph
} }
template <element::Type_t ET> template <element::Type_t ET>
const typename element_type_traits<ET>::value_type* get_data_ptr() const const typename element_type_traits<ET>::value_type* get_data_ptr() const {
{ NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type.");
NGRAPH_CHECK(ET == get_element_type(), return static_cast<const typename element_type_traits<ET>::value_type*>(get_data_ptr());
"get_data_ptr() called for incorrect element type.");
return static_cast<const typename element_type_traits<ET>::value_type*>(
get_data_ptr());
} }
bool get_all_data_elements_bitwise_identical() const bool get_all_data_elements_bitwise_identical() const {
{
return m_all_elements_bitwise_identical; return m_all_elements_bitwise_identical;
} }
std::string convert_value_to_string(size_t index) const; std::string convert_value_to_string(size_t index) const;
@ -341,46 +370,39 @@ namespace ngraph
/** /**
* \brief Allows to avoid buffer allocation on the visit_attributes call * \brief Allows to avoid buffer allocation on the visit_attributes call
*/ */
void alloc_buffer_on_visit_attributes(bool val) void alloc_buffer_on_visit_attributes(bool val) {
{
m_alloc_buffer_on_visit_attributes = val; m_alloc_buffer_on_visit_attributes = val;
} }
private: private:
template <element::Type_t Type, template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type != element::Type_t::u1 && typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::u4 &&
Type != element::Type_t::i4, Type != element::Type_t::i4,
bool>::type = true> bool>::type = true>
StorageDataType get_element_value(size_t index) const StorageDataType get_element_value(size_t index) const {
{
return get_data_ptr<Type>()[index]; return get_data_ptr<Type>()[index];
} }
template <element::Type_t Type, template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true> typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
StorageDataType get_element_value(size_t index) const StorageDataType get_element_value(size_t index) const {
{
return (get_data_ptr<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1; return (get_data_ptr<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1;
} }
template <element::Type_t Type, template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true> typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
StorageDataType get_element_value(size_t index) const StorageDataType get_element_value(size_t index) const {
{
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F; return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
} }
template <element::Type_t Type, template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true> typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
StorageDataType get_element_value(size_t index) const StorageDataType get_element_value(size_t index) const {
{ const uint8_t i4data = (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
const uint8_t i4data =
(get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
const bool is_negative_number = (i4data >> 3) & 0x01; const bool is_negative_number = (i4data >> 3) & 0x01;
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data; const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
return data; return data;
@ -388,12 +410,10 @@ namespace ngraph
template <element::Type_t Type, template <element::Type_t Type,
typename OUT_T, typename OUT_T,
typename std::enable_if<Type != element::Type_t::u1 && typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::u4 &&
Type != element::Type_t::i4, Type != element::Type_t::i4,
bool>::type = true> bool>::type = true>
void cast_vector(std::vector<OUT_T>& output_vector) const void cast_vector(std::vector<OUT_T>& output_vector) const {
{
// this function is workaround for waring during windows building // this function is workaround for waring during windows building
// build complains for vector creation based on iterators // build complains for vector creation based on iterators
// which point on different type than destination vector::value_type // which point on different type than destination vector::value_type
@ -401,28 +421,23 @@ namespace ngraph
auto source_vector = get_vector<IN_T>(); auto source_vector = get_vector<IN_T>();
output_vector.reserve(source_vector.size()); output_vector.reserve(source_vector.size());
std::transform(source_vector.begin(), std::transform(source_vector.begin(), source_vector.end(), std::back_inserter(output_vector), [](IN_T c) {
source_vector.end(), return static_cast<OUT_T>(c);
std::back_inserter(output_vector), });
[](IN_T c) { return static_cast<OUT_T>(c); });
} }
template <element::Type_t Type, template <element::Type_t Type,
typename OUT_T, typename OUT_T,
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true> typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
void cast_vector(std::vector<OUT_T>& output) const void cast_vector(std::vector<OUT_T>& output) const {
{
using IN_T = fundamental_type_for<Type>; using IN_T = fundamental_type_for<Type>;
const auto element_number = shape_size(m_shape); const auto element_number = shape_size(m_shape);
const auto source_begin = get_data_ptr<uint8_t>(); const auto source_begin = get_data_ptr<uint8_t>();
const auto source_end = std::next(source_begin, (element_number + 7) / 8); const auto source_end = std::next(source_begin, (element_number + 7) / 8);
const auto round_element_no = element_number % 8 const auto round_element_no = element_number % 8 ? element_number - element_number % 8 + 8 : element_number;
? element_number - element_number % 8 + 8
: element_number;
output.reserve(round_element_no); // adds 7 more elements here? output.reserve(round_element_no); // adds 7 more elements here?
std::for_each(source_begin, source_end, [&](IN_T c) { std::for_each(source_begin, source_end, [&](IN_T c) {
for (const auto i : {7, 6, 5, 4, 3, 2, 1, 0}) for (const auto i : {7, 6, 5, 4, 3, 2, 1, 0}) {
{
const uint8_t data = (c >> i) & 0x01; const uint8_t data = (c >> i) & 0x01;
output.push_back(data); output.push_back(data);
} }
@ -433,18 +448,15 @@ namespace ngraph
template <element::Type_t Type, template <element::Type_t Type,
typename OUT_T, typename OUT_T,
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true> typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
void cast_vector(std::vector<OUT_T>& output) const void cast_vector(std::vector<OUT_T>& output) const {
{
using IN_T = fundamental_type_for<Type>; using IN_T = fundamental_type_for<Type>;
const auto element_number = shape_size(m_shape); const auto element_number = shape_size(m_shape);
const auto source_begin = get_data_ptr<uint8_t>(); const auto source_begin = get_data_ptr<uint8_t>();
const auto source_end = std::next(source_begin, (element_number + 1) / 2); const auto source_end = std::next(source_begin, (element_number + 1) / 2);
const auto round_element_no = const auto round_element_no = element_number % 2 ? element_number + 1 : element_number;
element_number % 2 ? element_number + 1 : element_number;
output.reserve(round_element_no); // adds 1 more elements here? output.reserve(round_element_no); // adds 1 more elements here?
std::for_each(source_begin, source_end, [&](IN_T c) { std::for_each(source_begin, source_end, [&](IN_T c) {
for (const auto i : {4, 0}) for (const auto i : {4, 0}) {
{
const uint8_t data = (c >> i) & 0x0F; const uint8_t data = (c >> i) & 0x0F;
output.push_back(data); output.push_back(data);
} }
@ -454,18 +466,15 @@ namespace ngraph
template <element::Type_t Type, template <element::Type_t Type,
typename OUT_T, typename OUT_T,
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true> typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
void cast_vector(std::vector<OUT_T>& output) const void cast_vector(std::vector<OUT_T>& output) const {
{
using IN_T = fundamental_type_for<Type>; using IN_T = fundamental_type_for<Type>;
const auto element_number = shape_size(m_shape); const auto element_number = shape_size(m_shape);
const auto source_begin = get_data_ptr<uint8_t>(); const auto source_begin = get_data_ptr<uint8_t>();
const auto source_end = std::next(source_begin, (element_number + 1) / 2); const auto source_end = std::next(source_begin, (element_number + 1) / 2);
const auto round_element_no = const auto round_element_no = element_number % 2 ? element_number + 1 : element_number;
element_number % 2 ? element_number + 1 : element_number;
output.reserve(round_element_no); // adds 1 more elements here? output.reserve(round_element_no); // adds 1 more elements here?
std::for_each(source_begin, source_end, [&](IN_T c) { std::for_each(source_begin, source_end, [&](IN_T c) {
for (const auto i : {4, 0}) for (const auto i : {4, 0}) {
{
const uint8_t i4data = (c >> i) & 0x0F; const uint8_t i4data = (c >> i) & 0x0F;
const bool is_negative_number = (i4data >> 3) & 0x01; const bool is_negative_number = (i4data >> 3) & 0x01;
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data; const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
@ -478,12 +487,10 @@ namespace ngraph
template <element::Type_t Type, template <element::Type_t Type,
typename T, typename T,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type != element::Type_t::u1 && typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::u4 &&
Type != element::Type_t::i4, Type != element::Type_t::i4,
bool>::type = true> bool>::type = true>
void fill_data(const T& value) void fill_data(const T& value) {
{
const auto size = shape_size(m_shape); const auto size = shape_size(m_shape);
const auto v = static_cast<StorageDataType>(value); const auto v = static_cast<StorageDataType>(value);
std::fill_n(get_data_ptr_nc<Type>(), size, v); std::fill_n(get_data_ptr_nc<Type>(), size, v);
@ -493,8 +500,7 @@ namespace ngraph
typename T, typename T,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true> typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
void fill_data(const T& value) void fill_data(const T& value) {
{
const StorageDataType v = value ? 0xFF : 0x00; const StorageDataType v = value ? 0xFF : 0x00;
std::fill_n(get_data_ptr_nc<Type>(), mem_size(), v); std::fill_n(get_data_ptr_nc<Type>(), mem_size(), v);
} }
@ -502,11 +508,8 @@ namespace ngraph
template <element::Type_t Type, template <element::Type_t Type,
typename T, typename T,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u4 || typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool>::type = true>
Type == element::Type_t::i4, void fill_data(const T& value) {
bool>::type = true>
void fill_data(const T& value)
{
uint8_t v = value_in_range<Type>(value); uint8_t v = value_in_range<Type>(value);
v &= 0x0F; v &= 0x0F;
v += v << 4; v += v << 4;
@ -515,42 +518,33 @@ namespace ngraph
void allocate_buffer(); void allocate_buffer();
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); } void* get_data_ptr_nc() {
return (m_data ? m_data->get_ptr() : nullptr);
}
template <element::Type_t ET> template <element::Type_t ET>
typename element_type_traits<ET>::value_type* get_data_ptr_nc() typename element_type_traits<ET>::value_type* get_data_ptr_nc() {
{ NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr_nc() called for incorrect element type.");
NGRAPH_CHECK(ET == get_element_type(), return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr_nc());
"get_data_ptr_nc() called for incorrect element type.");
return static_cast<typename element_type_traits<ET>::value_type*>(
get_data_ptr_nc());
} }
Constant(const OutputVector& args) Constant(const OutputVector& args) : Op(args), m_shape({}) {}
: Op(args)
, m_shape({})
{
}
virtual void infer_element_type() {} virtual void infer_element_type() {}
template <typename T> template <typename T>
void write_values(const std::vector<T>& values) void write_values(const std::vector<T>& values) {
{
write_to_buffer(values); write_to_buffer(values);
} }
template <element::Type_t Type, template <element::Type_t Type,
typename T, typename T,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type != element::Type_t::u1 && typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
Type != element::Type_t::u4 &&
Type != element::Type_t::i4, Type != element::Type_t::i4,
bool>::type = true> bool>::type = true>
void write_buffer(const std::vector<T>& source) void write_buffer(const std::vector<T>& source) {
{
auto p = get_data_ptr_nc<Type>(); auto p = get_data_ptr_nc<Type>();
for (size_t i = 0; i < source.size(); i++) for (size_t i = 0; i < source.size(); i++) {
{
p[i] = static_cast<StorageDataType>(source[i]); p[i] = static_cast<StorageDataType>(source[i]);
} }
} }
@ -558,22 +552,17 @@ namespace ngraph
template <element::Type_t Type, template <element::Type_t Type,
typename T, typename T,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u4 || typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool>::type = true>
Type == element::Type_t::i4, void write_buffer(const std::vector<T>& source) {
bool>::type = true>
void write_buffer(const std::vector<T>& source)
{
auto p = get_data_ptr_nc<Type>(); auto p = get_data_ptr_nc<Type>();
size_t i = 0; size_t i = 0;
for (; i < source.size() / 2; i++) for (; i < source.size() / 2; i++) {
{
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F; const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F; const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F;
const auto v = (v1 << 4) | v2; const auto v = (v1 << 4) | v2;
p[i] = static_cast<StorageDataType>(v); p[i] = static_cast<StorageDataType>(v);
} }
if (source.size() % 2) if (source.size() % 2) {
{
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F; const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
const auto v = v1 << 4; const auto v = v1 << 4;
p[i] = static_cast<StorageDataType>(v); p[i] = static_cast<StorageDataType>(v);
@ -584,23 +573,19 @@ namespace ngraph
typename T, typename T,
typename StorageDataType = fundamental_type_for<Type>, typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u1, bool>::type = true> typename std::enable_if<Type == element::Type_t::u1, bool>::type = true>
void write_buffer(const std::vector<T>& source) void write_buffer(const std::vector<T>& source) {
{
auto p = get_data_ptr_nc<Type>(); auto p = get_data_ptr_nc<Type>();
size_t i = 0; size_t i = 0;
for (; i < source.size() / 8; i++) for (; i < source.size() / 8; i++) {
{
uint8_t v{}; uint8_t v{};
for (int j = 0; j != 8; j++) for (int j = 0; j != 8; j++) {
{
const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0; const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
v |= b; v |= b;
} }
p[i] = static_cast<StorageDataType>(v); p[i] = static_cast<StorageDataType>(v);
} }
uint8_t v{}; uint8_t v{};
for (unsigned j = 0; j != source.size() % 8; j++) for (unsigned j = 0; j != source.size() % 8; j++) {
{
const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0; const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
v |= b; v |= b;
} }
@ -608,12 +593,10 @@ namespace ngraph
} }
template <typename T> template <typename T>
void write_to_buffer(const std::vector<T>& source) void write_to_buffer(const std::vector<T>& source) {
{
const auto& target_type = m_element_type; const auto& target_type = m_element_type;
size_t target_element_count = shape_size(m_shape); size_t target_element_count = shape_size(m_shape);
if (source.size() != target_element_count) if (source.size() != target_element_count) {
{
throw std::runtime_error("Constant initializer does not match shape"); throw std::runtime_error("Constant initializer does not match shape");
} }
using Type_t = element::Type_t; using Type_t = element::Type_t;
@ -622,63 +605,89 @@ namespace ngraph
# pragma GCC diagnostic error "-Wswitch" # pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum" # pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (target_type) switch (target_type) {
{ case Type_t::boolean:
case Type_t::boolean: write_buffer<Type_t::boolean>(source); break; write_buffer<Type_t::boolean>(source);
case Type_t::bf16: write_buffer<Type_t::bf16>(source); break; break;
case Type_t::f16: write_buffer<Type_t::f16>(source); break; case Type_t::bf16:
case Type_t::f32: write_buffer<Type_t::f32>(source); break; write_buffer<Type_t::bf16>(source);
case Type_t::f64: write_buffer<Type_t::f64>(source); break; break;
case Type_t::i4: write_buffer<Type_t::i4>(source); break; case Type_t::f16:
case Type_t::i8: write_buffer<Type_t::i8>(source); break; write_buffer<Type_t::f16>(source);
case Type_t::i16: write_buffer<Type_t::i16>(source); break; break;
case Type_t::i32: write_buffer<Type_t::i32>(source); break; case Type_t::f32:
case Type_t::i64: write_buffer<Type_t::i64>(source); break; write_buffer<Type_t::f32>(source);
case Type_t::u1: write_buffer<Type_t::u1>(source); break; break;
case Type_t::u4: write_buffer<Type_t::u4>(source); break; case Type_t::f64:
case Type_t::u8: write_buffer<Type_t::u8>(source); break; write_buffer<Type_t::f64>(source);
case Type_t::u16: write_buffer<Type_t::u16>(source); break; break;
case Type_t::u32: write_buffer<Type_t::u32>(source); break; case Type_t::i4:
case Type_t::u64: write_buffer<Type_t::u64>(source); break; write_buffer<Type_t::i4>(source);
break;
case Type_t::i8:
write_buffer<Type_t::i8>(source);
break;
case Type_t::i16:
write_buffer<Type_t::i16>(source);
break;
case Type_t::i32:
write_buffer<Type_t::i32>(source);
break;
case Type_t::i64:
write_buffer<Type_t::i64>(source);
break;
case Type_t::u1:
write_buffer<Type_t::u1>(source);
break;
case Type_t::u4:
write_buffer<Type_t::u4>(source);
break;
case Type_t::u8:
write_buffer<Type_t::u8>(source);
break;
case Type_t::u16:
write_buffer<Type_t::u16>(source);
break;
case Type_t::u32:
write_buffer<Type_t::u32>(source);
break;
case Type_t::u64:
write_buffer<Type_t::u64>(source);
break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type"); case element::Type_t::dynamic:
throw std::runtime_error("unsupported type");
} }
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic pop # pragma GCC diagnostic pop
#endif #endif
} }
template < template <ngraph::element::Type_t Type,
ngraph::element::Type_t Type,
typename ValueT, typename ValueT,
typename std::enable_if<Type == ngraph::element::Type_t::u4, bool>::type = true> typename std::enable_if<Type == ngraph::element::Type_t::u4, bool>::type = true>
static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value) static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value) {
{
const auto result = ngraph::fundamental_type_for<Type>(value); const auto result = ngraph::fundamental_type_for<Type>(value);
NGRAPH_CHECK(0 <= result && result <= 15, NGRAPH_CHECK(0 <= result && result <= 15, "assigned value out of range u4 values");
"assigned value out of range u4 values");
return result; return result;
} }
template < template <ngraph::element::Type_t Type,
ngraph::element::Type_t Type,
typename ValueT, typename ValueT,
typename std::enable_if<Type == ngraph::element::Type_t::i4, bool>::type = true> typename std::enable_if<Type == ngraph::element::Type_t::i4, bool>::type = true>
static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value) static ngraph::fundamental_type_for<Type> value_in_range(const ValueT& value) {
{
const auto result = ngraph::fundamental_type_for<Type>(value); const auto result = ngraph::fundamental_type_for<Type>(value);
NGRAPH_CHECK(-8 <= result && result <= 7, NGRAPH_CHECK(-8 <= result && result <= 7, "assigned value out of range i4 values");
"assigned value out of range i4 values");
return result; return result;
} }
bool are_all_data_elements_bitwise_identical() const; bool are_all_data_elements_bitwise_identical() const;
static constexpr size_t host_alignment() { return 64; } static constexpr size_t host_alignment() {
return 64;
}
size_t mem_size() const size_t mem_size() const {
{
const bool bitwidth_less_than_byte = m_element_type.bitwidth() < 8; const bool bitwidth_less_than_byte = m_element_type.bitwidth() < 8;
if (bitwidth_less_than_byte) if (bitwidth_less_than_byte) {
{
const auto size = shape_size(m_shape); const auto size = shape_size(m_shape);
const auto bitwidth = size * m_element_type.bitwidth(); const auto bitwidth = size * m_element_type.bitwidth();
// for rounding by `(bitwidth + 7) / 8` will work for // for rounding by `(bitwidth + 7) / 8` will work for

View File

@ -7,15 +7,11 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise type conversion operation. /// \brief Elementwise type conversion operation.
class NGRAPH_API Convert : public Op class NGRAPH_API Convert : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -29,21 +25,21 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; const element::Type& get_destination_type() const {
const element::Type& get_destination_type() const { return m_destination_type; } return m_destination_type;
void set_destination_type(const element::Type& destination_type) }
{ void set_destination_type(const element::Type& destination_type) {
m_destination_type = destination_type; m_destination_type = destination_type;
} }
const element::Type& get_convert_element_type() const { return m_destination_type; } const element::Type& get_convert_element_type() const {
void set_convert_element_type(const element::Type& destination_type) return m_destination_type;
{ }
void set_convert_element_type(const element::Type& destination_type) {
m_destination_type = destination_type; m_destination_type = destination_type;
} }
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& outputs) const override; bool evaluate_lower(const HostTensorVector& outputs) const override;
bool evaluate_upper(const HostTensorVector& outputs) const override; bool evaluate_upper(const HostTensorVector& outputs) const override;

View File

@ -6,15 +6,11 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Elementwise type conversion operation. /// \brief Elementwise type conversion operation.
class NGRAPH_API ConvertLike : public Op class NGRAPH_API ConvertLike : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -28,11 +24,9 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool constant_fold(OutputVector& output_values, bool constant_fold(OutputVector& output_values, const OutputVector& input_values) override;
const OutputVector& input_values) override;
}; };
} // namespace v1 } // namespace v1

View File

@ -8,16 +8,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Batched convolution operation, with optional window dilation and stride. /// \brief Batched convolution operation, with optional window dilation and stride.
/// ///
class NGRAPH_API Convolution : public Op class NGRAPH_API Convolution : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -53,24 +49,43 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The strides. /// \return The strides.
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const {
void set_strides(const Strides& strides) { m_strides = strides; } return m_strides;
}
void set_strides(const Strides& strides) {
m_strides = strides;
}
/// \return The dilations. /// \return The dilations.
const Strides& get_dilations() const { return m_dilations; } const Strides& get_dilations() const {
void set_dilations(const Strides& dilations) { m_dilations = dilations; } return m_dilations;
}
void set_dilations(const Strides& dilations) {
m_dilations = dilations;
}
/// \return The padding-below sizes (possibly negative). /// \return The padding-below sizes (possibly negative).
const CoordinateDiff& get_pads_begin() const { return m_pads_begin; } const CoordinateDiff& get_pads_begin() const {
void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; } return m_pads_begin;
}
void set_pads_begin(const CoordinateDiff& pads_begin) {
m_pads_begin = pads_begin;
}
/// \return The padding-above sizes (possibly negative). /// \return The padding-above sizes (possibly negative).
const CoordinateDiff& get_pads_end() const { return m_pads_end; } const CoordinateDiff& get_pads_end() const {
void set_adding_above(const CoordinateDiff& pads_end) { m_pads_end = pads_end; } return m_pads_end;
}
void set_adding_above(const CoordinateDiff& pads_end) {
m_pads_end = pads_end;
}
/// \return The pad type for convolution. /// \return The pad type for convolution.
const PadType& get_auto_pad() const { return m_auto_pad; } const PadType& get_auto_pad() const {
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; } return m_auto_pad;
}
void set_auto_pad(const PadType& auto_pad) {
m_auto_pad = auto_pad;
}
/// \return The default value for Convolution. /// \return The default value for Convolution.
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
@ -85,8 +100,7 @@ namespace ngraph
}; };
/// \brief Data batch backprop for batched convolution operation. /// \brief Data batch backprop for batched convolution operation.
class NGRAPH_API ConvolutionBackpropData : public Op class NGRAPH_API ConvolutionBackpropData : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise cosine operation. /// \brief Elementwise cosine operation.
class NGRAPH_API Cos : public util::UnaryElementwiseArithmetic class NGRAPH_API Cos : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -26,10 +22,8 @@ namespace ngraph
Cos(const Output<Node>& arg); Cos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise hyperbolic cosine (cosh) operation. /// \brief Elementwise hyperbolic cosine (cosh) operation.
class NGRAPH_API Cosh : public util::UnaryElementwiseArithmetic class NGRAPH_API Cosh : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -26,10 +22,8 @@ namespace ngraph
Cosh(const Output<Node>& arg); Cosh(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -6,14 +6,10 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{ class NGRAPH_API CTCGreedyDecoder : public Op {
namespace v0
{
class NGRAPH_API CTCGreedyDecoder : public Op
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -23,16 +19,15 @@ namespace ngraph
/// \param input Logits on which greedy decoding is performed /// \param input Logits on which greedy decoding is performed
/// \param seq_len Sequence lengths /// \param seq_len Sequence lengths
/// \param ctc_merge_repeated Whether to merge repeated labels /// \param ctc_merge_repeated Whether to merge repeated labels
CTCGreedyDecoder(const Output<Node>& input, CTCGreedyDecoder(const Output<Node>& input, const Output<Node>& seq_len, const bool ctc_merge_repeated);
const Output<Node>& seq_len,
const bool ctc_merge_repeated);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_ctc_merge_repeated() const { return m_ctc_merge_repeated; } bool get_ctc_merge_repeated() const {
return m_ctc_merge_repeated;
}
private: private:
bool m_ctc_merge_repeated; bool m_ctc_merge_repeated;

View File

@ -6,16 +6,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief Operator performing CTCGreedyDecoder /// \brief Operator performing CTCGreedyDecoder
/// ///
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op class NGRAPH_API CTCGreedyDecoderSeqLen : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
CTCGreedyDecoderSeqLen() = default; CTCGreedyDecoderSeqLen() = default;
@ -52,25 +48,27 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Get merge_repeated attribute /// \brief Get merge_repeated attribute
/// ///
/// \return Current value of merge_repeated attribute /// \return Current value of merge_repeated attribute
/// ///
bool get_merge_repeated() const { return m_merge_repeated; } bool get_merge_repeated() const {
return m_merge_repeated;
}
/// \brief Get classes_index_type attribute /// \brief Get classes_index_type attribute
/// ///
/// \return Current value of classes_index_type attribute /// \return Current value of classes_index_type attribute
/// ///
const element::Type& get_classes_index_type() const { return m_classes_index_type; } const element::Type& get_classes_index_type() const {
return m_classes_index_type;
}
/// \brief Set classes_index_type attribute /// \brief Set classes_index_type attribute
/// ///
/// \param classes_index_type Type of classes_index /// \param classes_index_type Type of classes_index
/// ///
void set_classes_index_type(const element::Type& classes_index_type) void set_classes_index_type(const element::Type& classes_index_type) {
{
m_classes_index_type = classes_index_type; m_classes_index_type = classes_index_type;
validate_and_infer_types(); validate_and_infer_types();
} }
@ -79,8 +77,7 @@ namespace ngraph
/// ///
/// \return Current value of sequence_length_type attribute /// \return Current value of sequence_length_type attribute
/// ///
const element::Type& get_sequence_length_type() const const element::Type& get_sequence_length_type() const {
{
return m_sequence_length_type; return m_sequence_length_type;
} }
@ -88,8 +85,7 @@ namespace ngraph
/// ///
/// \param sequence_length_type Type of sequence length /// \param sequence_length_type Type of sequence length
/// ///
void set_sequence_length_type(const element::Type& sequence_length_type) void set_sequence_length_type(const element::Type& sequence_length_type) {
{
m_sequence_length_type = sequence_length_type; m_sequence_length_type = sequence_length_type;
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -6,17 +6,15 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v4 {
{ class NGRAPH_API CTCLoss : public Op {
namespace v4
{
class NGRAPH_API CTCLoss : public Op
{
public: public:
static constexpr NodeTypeInfo type_info{"CTCLoss", 0}; static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
CTCLoss() = default; CTCLoss() = default;
/// \brief Constructs a CTCLoss operation /// \brief Constructs a CTCLoss operation
/// ///
@ -53,15 +51,17 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_preprocess_collapse_repeated() const bool get_preprocess_collapse_repeated() const {
{
return preprocess_collapse_repeated_; return preprocess_collapse_repeated_;
} }
bool get_ctc_merge_repeated() const { return ctc_merge_repeated_; } bool get_ctc_merge_repeated() const {
bool get_unique() const { return unique_; } return ctc_merge_repeated_;
}
bool get_unique() const {
return unique_;
}
private: private:
bool preprocess_collapse_repeated_; bool preprocess_collapse_repeated_;

View File

@ -7,12 +7,9 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Tensor cumulative sum operation. /// \brief Tensor cumulative sum operation.
/// ///
/// Compute the cumulative sum of the input tensor along the axis specified. /// Compute the cumulative sum of the input tensor along the axis specified.
@ -57,8 +54,7 @@ namespace ngraph
/// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements /// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements
/// | /// |
class NGRAPH_API CumSum : public Op class NGRAPH_API CumSum : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -72,20 +68,14 @@ namespace ngraph
/// cumulative sum must be performed /// cumulative sum must be performed
/// \param exclusive if set to true, the top element is not included /// \param exclusive if set to true, the top element is not included
/// \param reverse if set to true, will perform the sums in reverse direction /// \param reverse if set to true, will perform the sums in reverse direction
CumSum(const Output<Node>& arg, CumSum(const Output<Node>& arg, const Output<Node>& axis, const bool exclusive = false, const bool reverse = false);
const Output<Node>& axis,
const bool exclusive = false,
const bool reverse = false);
/// \brief Constructs a cumulative summation operation with axis = 0 /// \brief Constructs a cumulative summation operation with axis = 0
/// ///
/// \param arg The tensor to be summed /// \param arg The tensor to be summed
CumSum(const Output<Node>& arg, CumSum(const Output<Node>& arg, const bool exclusive = false, const bool reverse = false);
const bool exclusive = false,
const bool reverse = false);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
@ -94,8 +84,12 @@ namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_SUPPRESS_DEPRECATED_END
bool is_exclusive() const { return m_exclusive; } bool is_exclusive() const {
bool is_reverse() const { return m_reverse; } return m_exclusive;
}
bool is_reverse() const {
return m_reverse;
}
private: private:
bool m_exclusive; bool m_exclusive;

View File

@ -9,15 +9,11 @@
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/deformable_convolution_base.hpp" #include "ngraph/op/util/deformable_convolution_base.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief DeformableConvolution operation. /// \brief DeformableConvolution operation.
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -55,15 +51,12 @@ namespace ngraph
const int64_t group = 1, const int64_t group = 1,
const int64_t deformable_group = 1); const int64_t deformable_group = 1);
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v1 } // namespace v1
namespace v8 namespace v8 {
{ class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase {
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -160,18 +153,17 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_bilinear_interpolation_pad() const { return m_bilinear_interpolation_pad; } bool get_bilinear_interpolation_pad() const {
return m_bilinear_interpolation_pad;
}
void set_bilinear_interpolation_pad(const bool bilinear_interpolation_pad) void set_bilinear_interpolation_pad(const bool bilinear_interpolation_pad) {
{
m_bilinear_interpolation_pad = bilinear_interpolation_pad; m_bilinear_interpolation_pad = bilinear_interpolation_pad;
} }

View File

@ -6,14 +6,10 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{ class NGRAPH_API DeformablePSROIPooling : public Op {
namespace v1
{
class NGRAPH_API DeformablePSROIPooling : public Op
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -69,17 +65,32 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_output_dim() const { return m_output_dim; } int64_t get_output_dim() const {
int64_t get_group_size() const { return m_group_size; } return m_output_dim;
float get_spatial_scale() const { return m_spatial_scale; } }
const std::string& get_mode() const { return m_mode; } int64_t get_group_size() const {
int64_t get_spatial_bins_x() const { return m_spatial_bins_x; } return m_group_size;
int64_t get_spatial_bins_y() const { return m_spatial_bins_y; } }
float get_trans_std() const { return m_trans_std; } float get_spatial_scale() const {
int64_t get_part_size() const { return m_part_size; } return m_spatial_scale;
}
const std::string& get_mode() const {
return m_mode;
}
int64_t get_spatial_bins_x() const {
return m_spatial_bins_x;
}
int64_t get_spatial_bins_y() const {
return m_spatial_bins_y;
}
float get_trans_std() const {
return m_trans_std;
}
int64_t get_part_size() const {
return m_part_size;
}
private: private:
int64_t m_output_dim; int64_t m_output_dim;

View File

@ -7,12 +7,9 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief DepthToSpace permutes data from the depth dimension of the input blob into /// \brief DepthToSpace permutes data from the depth dimension of the input blob into
/// spatial dimensions. /// spatial dimensions.
/// ///
@ -21,13 +18,11 @@ namespace ngraph
/// ///
/// Output node produces a tensor with shape: /// Output node produces a tensor with shape:
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize] /// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
class NGRAPH_API DepthToSpace : public Op class NGRAPH_API DepthToSpace : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
enum class DepthToSpaceMode enum class DepthToSpaceMode {
{
// The input depth is divided to [block_size, ..., block_size, new_depth] // The input depth is divided to [block_size, ..., block_size, new_depth]
BLOCKS_FIRST, BLOCKS_FIRST,
// The input depth is divided to [new_depth, block_size, ..., block_size] // The input depth is divided to [new_depth, block_size, ..., block_size]
@ -41,22 +36,20 @@ namespace ngraph
/// \param mode Specifies how the input depth dimension is split to block /// \param mode Specifies how the input depth dimension is split to block
/// coordinates /// coordinates
/// \param block_size The size of the block of values to be moved /// \param block_size The size of the block of values to be moved
DepthToSpace(const Output<Node>& data, DepthToSpace(const Output<Node>& data, const DepthToSpaceMode& mode, std::size_t block_size = 1);
const DepthToSpaceMode& mode,
std::size_t block_size = 1);
DepthToSpace(const Output<Node>& data, DepthToSpace(const Output<Node>& data, const std::string& mode, std::size_t block_size = 1);
const std::string& mode,
std::size_t block_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::size_t get_block_size() const { return m_blocksize; } std::size_t get_block_size() const {
DepthToSpaceMode get_mode() const { return m_mode; } return m_blocksize;
virtual std::shared_ptr<Node> }
clone_with_new_inputs(const OutputVector& new_args) const override; DepthToSpaceMode get_mode() const {
return m_mode;
}
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
protected: protected:
@ -72,16 +65,14 @@ namespace ngraph
template <> template <>
class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode> class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>
: public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode> : public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode> {
{
public: public:
AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value) AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value)
: EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value) : EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value) {}
{
}
static constexpr DiscreteTypeInfo type_info{ static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0};
"AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0}; const DiscreteTypeInfo& get_type_info() const override {
const DiscreteTypeInfo& get_type_info() const override { return type_info; } return type_info;
}
}; };
} // namespace ngraph } // namespace ngraph

View File

@ -6,12 +6,9 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op struct DetectionOutputAttrs {
{
struct DetectionOutputAttrs
{
int num_classes; int num_classes;
int background_label_id = 0; int background_label_id = 0;
int top_k = -1; int top_k = -1;
@ -30,12 +27,10 @@ namespace ngraph
float objectness_score = 0; float objectness_score = 0;
}; };
namespace v0 namespace v0 {
{
/// \brief Layer which performs non-max suppression to /// \brief Layer which performs non-max suppression to
/// generate detection output using location and confidence predictions /// generate detection output using location and confidence predictions
class NGRAPH_API DetectionOutput : public Op class NGRAPH_API DetectionOutput : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -68,10 +63,11 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
const DetectionOutputAttrs& get_attrs() const { return m_attrs; } const DetectionOutputAttrs& get_attrs() const {
return m_attrs;
}
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
private: private:

View File

@ -18,20 +18,17 @@
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fft_base.hpp" #include "ngraph/op/util/fft_base.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v7 {
{
namespace v7
{
/// \brief An operation DFT that computes the discrete Fourier transformation. /// \brief An operation DFT that computes the discrete Fourier transformation.
class NGRAPH_API DFT : public util::FFTBase class NGRAPH_API DFT : public util::FFTBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
DFT() = default; DFT() = default;
@ -47,14 +44,11 @@ namespace ngraph
/// \param data Input data /// \param data Input data
/// \param axes Axes to perform DFT /// \param axes Axes to perform DFT
/// \param signal_size Signal sizes for 'axes' /// \param signal_size Signal sizes for 'axes'
DFT(const Output<Node>& data, DFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
const Output<Node>& axes,
const Output<Node>& signal_size);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v7 } // namespace v7
} // namespace op } // namespace op

View File

@ -6,22 +6,15 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Elementwise division operation. /// \brief Elementwise division operation.
class NGRAPH_API Divide : public util::BinaryElementwiseArithmetic class NGRAPH_API Divide : public util::BinaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
Divide() Divide() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) {}
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
/// ///
@ -32,8 +25,7 @@ namespace ngraph
Divide(const Output<Node>& arg0, Divide(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
bool pythondiv, bool pythondiv,
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
/// ///
@ -42,16 +34,17 @@ namespace ngraph
/// \param auto_broadcast Auto broadcast specification /// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0, Divide(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
bool is_pythondiv() const { return m_pythondiv; } bool is_pythondiv() const {
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; } return m_pythondiv;
virtual std::shared_ptr<Node> }
clone_with_new_inputs(const OutputVector& new_args) const override; void set_is_pythondiv(bool pythondiv) {
m_pythondiv = pythondiv;
}
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
protected: protected:

View File

@ -7,15 +7,11 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v7 {
{
namespace v7
{
/// \brief Einsum operation. /// \brief Einsum operation.
class NGRAPH_API Einsum : public Op class NGRAPH_API Einsum : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -35,14 +31,15 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Get an equation of Einsum operation /// \brief Get an equation of Einsum operation
/// ///
/// \return Einsum equation /// \return Einsum equation
/// ///
std::string get_equation() const { return m_equation; } std::string get_equation() const {
return m_equation;
}
/// \brief Check correctness of equation format and extract input subscripts /// \brief Check correctness of equation format and extract input subscripts
/// and output subscript /// and output subscript

View File

@ -7,18 +7,14 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Exponential Linear Unit /// \brief Exponential Linear Unit
/// x < 0 => f(x) = alpha * (exp(x) - 1.) /// x < 0 => f(x) = alpha * (exp(x) - 1.)
/// x >= 0 => f(x) = x /// x >= 0 => f(x) = x
/// ///
class NGRAPH_API Elu : public ngraph::op::Op class NGRAPH_API Elu : public ngraph::op::Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -32,10 +28,11 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
double get_alpha() const { return m_alpha; } double get_alpha() const {
return m_alpha;
}
private: private:
double m_alpha; double m_alpha;

View File

@ -7,18 +7,16 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Returns embeddings for given indices /// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingSegmentsSum : public Op class NGRAPH_API EmbeddingSegmentsSum : public Op {
{
public: public:
static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3}; static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs a EmbeddingSegmentsSum operation. /// \brief Constructs a EmbeddingSegmentsSum operation.
EmbeddingSegmentsSum() = default; EmbeddingSegmentsSum() = default;
/// \brief Constructs a EmbeddingSegmentsSum operation. /// \brief Constructs a EmbeddingSegmentsSum operation.
@ -63,10 +61,11 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; } bool visit_attributes(AttributeVisitor&) override {
return true;
}
private: private:
static constexpr int EMB_TABLE = 0; static constexpr int EMB_TABLE = 0;

View File

@ -8,18 +8,16 @@
#include "ngraph/op/util/embeddingbag_offsets_base.hpp" #include "ngraph/op/util/embeddingbag_offsets_base.hpp"
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Returns embeddings for given indices /// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingBagOffsetsSum : public util::EmbeddingBagOffsetsBase class NGRAPH_API EmbeddingBagOffsetsSum : public util::EmbeddingBagOffsetsBase {
{
public: public:
static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsSum", 3}; static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsSum", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs a EmbeddingBagOffsetsSum operation. /// \brief Constructs a EmbeddingBagOffsetsSum operation.
EmbeddingBagOffsetsSum() = default; EmbeddingBagOffsetsSum() = default;
/// \brief Constructs a EmbeddingBagOffsetsSum operation. /// \brief Constructs a EmbeddingBagOffsetsSum operation.
@ -51,12 +49,9 @@ namespace ngraph
const Output<Node>& offsets, const Output<Node>& offsets,
const Output<Node>& default_index); const Output<Node>& default_index);
EmbeddingBagOffsetsSum(const Output<Node>& emb_table, EmbeddingBagOffsetsSum(const Output<Node>& emb_table, const Output<Node>& indices, const Output<Node>& offsets);
const Output<Node>& indices,
const Output<Node>& offsets);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v3 } // namespace v3
using v3::EmbeddingBagOffsetsSum; using v3::EmbeddingBagOffsetsSum;

View File

@ -8,18 +8,16 @@
#include "ngraph/op/util/embeddingbag_packed_base.hpp" #include "ngraph/op/util/embeddingbag_packed_base.hpp"
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{
namespace v3
{
/// \brief Returns embeddings for given indices /// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingBagPackedSum : public util::EmbeddingBagPackedBase class NGRAPH_API EmbeddingBagPackedSum : public util::EmbeddingBagPackedBase {
{
public: public:
static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedSum", 3}; static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedSum", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs a EmbeddingBagPackedSum operation. /// \brief Constructs a EmbeddingBagPackedSum operation.
EmbeddingBagPackedSum() = default; EmbeddingBagPackedSum() = default;
/// \brief Constructs a EmbeddingBagPackedSum operation. /// \brief Constructs a EmbeddingBagPackedSum operation.
@ -42,8 +40,7 @@ namespace ngraph
EmbeddingBagPackedSum(const Output<Node>& emb_table, const Output<Node>& indices); EmbeddingBagPackedSum(const Output<Node>& emb_table, const Output<Node>& indices);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v3 } // namespace v3
using v3::EmbeddingBagPackedSum; using v3::EmbeddingBagPackedSum;

View File

@ -6,12 +6,9 @@
#include "ngraph/op/util/binary_elementwise_comparison.hpp" #include "ngraph/op/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
// clang-format off // clang-format off
/// \brief Elementwise is-equal operation. /// \brief Elementwise is-equal operation.
/// ///
@ -29,15 +26,11 @@ namespace ngraph
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | /// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on // clang-format on
class NGRAPH_API Equal : public util::BinaryElementwiseComparison class NGRAPH_API Equal : public util::BinaryElementwiseComparison {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an equal operation. /// \brief Constructs an equal operation.
Equal() Equal() : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY) {}
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an equal operation. /// \brief Constructs an equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
@ -45,15 +38,12 @@ namespace ngraph
/// \param auto_broadcast Auto broadcast specification /// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0, Equal(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v1 } // namespace v1

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise erf operation. /// \brief Elementwise erf operation.
class NGRAPH_API Erf : public util::UnaryElementwiseArithmetic class NGRAPH_API Erf : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a floor operation. /// \brief Constructs a floor operation.
@ -25,10 +21,8 @@ namespace ngraph
Erf(const Output<Node>& arg); Erf(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise natural exponential (exp) operation. /// \brief Elementwise natural exponential (exp) operation.
class NGRAPH_API Exp : public util::UnaryElementwiseArithmetic class NGRAPH_API Exp : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -26,11 +22,9 @@ namespace ngraph
Exp(const Output<Node>& arg); Exp(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -6,27 +6,23 @@
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief An operation ExperimentalDetectronDetectionOutput performs /// \brief An operation ExperimentalDetectronDetectionOutput performs
/// non-maximum suppression to generate the detection output using /// non-maximum suppression to generate the detection output using
/// information on location and score predictions. /// information on location and score predictions.
class NGRAPH_API ExperimentalDetectronDetectionOutput : public Op class NGRAPH_API ExperimentalDetectronDetectionOutput : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation /// \brief Structure that specifies attributes of the operation
struct Attributes struct Attributes {
{
// specifies score threshold // specifies score threshold
float score_threshold; float score_threshold;
// specifies NMS threshold // specifies NMS threshold
@ -64,10 +60,11 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of the operation ExperimentalDetectronDetectionOutput /// \brief Returns attributes of the operation ExperimentalDetectronDetectionOutput
const Attributes& get_attrs() const { return m_attrs; } const Attributes& get_attrs() const {
return m_attrs;
}
private: private:
Attributes m_attrs; Attributes m_attrs;

View File

@ -6,26 +6,22 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief An operation ExperimentalDetectronGenerateProposalsSingleImage /// \brief An operation ExperimentalDetectronGenerateProposalsSingleImage
/// computes ROIs and their scores based on input data. /// computes ROIs and their scores based on input data.
class NGRAPH_API ExperimentalDetectronGenerateProposalsSingleImage : public Op class NGRAPH_API ExperimentalDetectronGenerateProposalsSingleImage : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation /// \brief Structure that specifies attributes of the operation
struct Attributes struct Attributes {
{
// minimum box width & height // minimum box width & height
float min_size; float min_size;
// specifies NMS threshold // specifies NMS threshold
@ -54,10 +50,11 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
const Attributes& get_attrs() const { return m_attrs; } const Attributes& get_attrs() const {
return m_attrs;
}
private: private:
Attributes m_attrs; Attributes m_attrs;

View File

@ -6,26 +6,22 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief An operation ExperimentalDetectronPriorGridGenerator generates prior /// \brief An operation ExperimentalDetectronPriorGridGenerator generates prior
/// grids of specified sizes. /// grids of specified sizes.
class NGRAPH_API ExperimentalDetectronPriorGridGenerator : public Op class NGRAPH_API ExperimentalDetectronPriorGridGenerator : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation /// \brief Structure that specifies attributes of the operation
struct Attributes struct Attributes {
{
// Specifies whether the output tensor should be 2D or 4D // Specifies whether the output tensor should be 2D or 4D
// `true` means the output tensor should be 2D tensor, // `true` means the output tensor should be 2D tensor,
// `false` means the output tensor should be 4D tensor. // `false` means the output tensor should be 4D tensor.
@ -55,10 +51,11 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of this operation. /// \brief Returns attributes of this operation.
const Attributes& get_attrs() const { return m_attrs; } const Attributes& get_attrs() const {
return m_attrs;
}
private: private:
Attributes m_attrs; Attributes m_attrs;

View File

@ -7,26 +7,22 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief An operation ExperimentalDetectronROIFeatureExtractor /// \brief An operation ExperimentalDetectronROIFeatureExtractor
/// is the ROIAlign operation applied over a feature pyramid. /// is the ROIAlign operation applied over a feature pyramid.
class NGRAPH_API ExperimentalDetectronROIFeatureExtractor : public Op class NGRAPH_API ExperimentalDetectronROIFeatureExtractor : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation /// \brief Structure that specifies attributes of the operation
struct Attributes struct Attributes {
{
int64_t output_size; int64_t output_size;
int64_t sampling_ratio; int64_t sampling_ratio;
std::vector<int64_t> pyramid_scales; std::vector<int64_t> pyramid_scales;
@ -38,23 +34,22 @@ namespace ngraph
/// ///
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor /// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
/// \param attrs Operation attributes /// \param attrs Operation attributes
ExperimentalDetectronROIFeatureExtractor(const OutputVector& args, ExperimentalDetectronROIFeatureExtractor(const OutputVector& args, const Attributes& attrs);
const Attributes& attrs);
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation. /// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
/// ///
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor /// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
/// \param attrs Operation attributes /// \param attrs Operation attributes
ExperimentalDetectronROIFeatureExtractor(const NodeVector& args, ExperimentalDetectronROIFeatureExtractor(const NodeVector& args, const Attributes& attrs);
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of the operation. /// \brief Returns attributes of the operation.
const Attributes& get_attrs() const { return m_attrs; } const Attributes& get_attrs() const {
return m_attrs;
}
private: private:
Attributes m_attrs; Attributes m_attrs;

View File

@ -6,20 +6,17 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief An operation ExperimentalDetectronTopKROIs, according to the repository /// \brief An operation ExperimentalDetectronTopKROIs, according to the repository
/// is TopK operation applied to probabilities of input ROIs. /// is TopK operation applied to probabilities of input ROIs.
class NGRAPH_API ExperimentalDetectronTopKROIs : public Op class NGRAPH_API ExperimentalDetectronTopKROIs : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -29,17 +26,16 @@ namespace ngraph
/// \param input_rois Input rois /// \param input_rois Input rois
/// \param rois_probs Probabilities for input rois /// \param rois_probs Probabilities for input rois
/// \param max_rois Maximal numbers of output rois /// \param max_rois Maximal numbers of output rois
ExperimentalDetectronTopKROIs(const Output<Node>& input_rois, ExperimentalDetectronTopKROIs(const Output<Node>& input_rois, const Output<Node>& rois_probs, size_t max_rois = 0);
const Output<Node>& rois_probs,
size_t max_rois = 0);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_max_rois() const { return m_max_rois; } size_t get_max_rois() const {
return m_max_rois;
}
private: private:
size_t m_max_rois; size_t m_max_rois;

View File

@ -6,14 +6,10 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v3 {
{ class NGRAPH_API ExtractImagePatches : public Op {
namespace v3
{
class NGRAPH_API ExtractImagePatches : public Op
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -36,17 +32,32 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
const Shape& get_sizes() const { return m_patch_sizes; } const Shape& get_sizes() const {
void set_sizes(const Shape& sizes) { m_patch_sizes = sizes; } return m_patch_sizes;
const Strides& get_strides() const { return m_patch_movement_strides; } }
void set_strides(const Strides& strides) { m_patch_movement_strides = strides; } void set_sizes(const Shape& sizes) {
const Shape& get_rates() const { return m_patch_selection_rates; } m_patch_sizes = sizes;
void set_rates(const Shape& rates) { m_patch_selection_rates = rates; } }
const PadType& get_auto_pad() const { return m_padding; } const Strides& get_strides() const {
void set_auto_pad(PadType& padding) { m_padding = padding; } return m_patch_movement_strides;
}
void set_strides(const Strides& strides) {
m_patch_movement_strides = strides;
}
const Shape& get_rates() const {
return m_patch_selection_rates;
}
void set_rates(const Shape& rates) {
m_patch_selection_rates = rates;
}
const PadType& get_auto_pad() const {
return m_padding;
}
void set_auto_pad(PadType& padding) {
m_padding = padding;
}
private: private:
Shape m_patch_sizes; Shape m_patch_sizes;

View File

@ -8,12 +8,9 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// ///
/// \brief Class performing element-wise linear quantization. /// \brief Class performing element-wise linear quantization.
/// ///
@ -27,8 +24,7 @@ namespace ngraph
/// (levels-1) * (output_high - output_low) + output_low /// (levels-1) * (output_high - output_low) + output_low
/// ///
/// ///
class NGRAPH_API FakeQuantize : public ngraph::op::Op class NGRAPH_API FakeQuantize : public ngraph::op::Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -51,20 +47,23 @@ namespace ngraph
const Output<Node>& output_low, const Output<Node>& output_low,
const Output<Node>& output_high, const Output<Node>& output_high,
std::size_t levels, std::size_t levels,
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual void validate_and_infer_types() override; virtual void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
std::size_t get_levels() const { return m_levels; } std::size_t get_levels() const {
void set_levels(std::size_t levels) { m_levels = levels; } return m_levels;
const AutoBroadcastSpec& get_auto_broadcast() const { return m_auto_broadcast; } }
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast) void set_levels(std::size_t levels) {
{ m_levels = levels;
}
const AutoBroadcastSpec& get_auto_broadcast() const {
return m_auto_broadcast;
}
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast) {
m_auto_broadcast = auto_broadcast; m_auto_broadcast = auto_broadcast;
} }

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v0 {
{
namespace v0
{
/// \brief Elementwise floor operation. /// \brief Elementwise floor operation.
class NGRAPH_API Floor : public util::UnaryElementwiseArithmetic class NGRAPH_API Floor : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a floor operation. /// \brief Constructs a floor operation.
@ -25,10 +21,8 @@ namespace ngraph
Floor(const Output<Node>& arg); Floor(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v0 } // namespace v0

View File

@ -8,22 +8,17 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Elementwise FloorMod operation. /// \brief Elementwise FloorMod operation.
/// ///
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an uninitialized addition operation /// \brief Constructs an uninitialized addition operation
FloorMod() FloorMod() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
/// \brief Constructs an Floor Mod operation. /// \brief Constructs an Floor Mod operation.
/// ///
@ -39,11 +34,9 @@ namespace ngraph
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
}; };
} // namespace v1 } // namespace v1

View File

@ -6,15 +6,11 @@
#include "ngraph/op/util/gather_base.hpp" #include "ngraph/op/util/gather_base.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Gather slices from axis of params according to indices /// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase class NGRAPH_API Gather : public op::util::GatherBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max(); static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
@ -22,23 +18,18 @@ namespace ngraph
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather /// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from /// \param axis The tensor is a dimension index to gather data from
Gather(const Output<Node>& params, Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axis);
const Output<Node>& indices,
const Output<Node>& axis);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
int64_t get_axis() const override; int64_t get_axis() const override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v1 } // namespace v1
namespace v7 namespace v7 {
{
/// \brief Gather slices from axis of params according to indices /// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase class NGRAPH_API Gather : public op::util::GatherBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
Gather() = default; Gather() = default;
@ -57,16 +48,13 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
int64_t get_batch_dims() const; int64_t get_batch_dims() const;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v7 } // namespace v7
namespace v8 namespace v8 {
{
/// \brief Gather slices from axis of params according to indices /// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase class NGRAPH_API Gather : public op::util::GatherBase {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
Gather() = default; Gather() = default;
@ -84,8 +72,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
int64_t get_batch_dims() const; int64_t get_batch_dims() const;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v8 } // namespace v8
} // namespace op } // namespace op

View File

@ -6,16 +6,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v6 {
{
namespace v6
{
/// \brief GatherElements operation /// \brief GatherElements operation
/// ///
class NGRAPH_API GatherElements : public Op class NGRAPH_API GatherElements : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
GatherElements() = default; GatherElements() = default;
@ -25,16 +21,15 @@ namespace ngraph
/// \param data Node producing data that are gathered /// \param data Node producing data that are gathered
/// \param indices Node producing indices by which the operation gathers elements /// \param indices Node producing indices by which the operation gathers elements
/// \param axis specifies axis along which indices are specified /// \param axis specifies axis along which indices are specified
GatherElements(const Output<Node>& data, GatherElements(const Output<Node>& data, const Output<Node>& indices, const int64_t axis);
const Output<Node>& indices,
const int64_t axis);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_axis() const { return m_axis; } int64_t get_axis() const {
return m_axis;
}
private: private:
int64_t m_axis; int64_t m_axis;

View File

@ -6,16 +6,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v5 {
{
namespace v5
{
/// \brief GatherND operation /// \brief GatherND operation
/// ///
class NGRAPH_API GatherND : public Op class NGRAPH_API GatherND : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
GatherND() = default; GatherND() = default;
@ -26,16 +22,15 @@ namespace ngraph
/// \param indices Node producing indices by which the operation gathers elements /// \param indices Node producing indices by which the operation gathers elements
/// or slices from data /// or slices from data
/// \param batch_dims Specifies a number of batch dimensions /// \param batch_dims Specifies a number of batch dimensions
GatherND(const Output<Node>& data, GatherND(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims = 0);
const Output<Node>& indices,
const size_t batch_dims = 0);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_batch_dims() const { return m_batch_dims; } size_t get_batch_dims() const {
return m_batch_dims;
}
private: private:
size_t m_batch_dims; size_t m_batch_dims;

View File

@ -6,16 +6,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op namespace v1 {
{
namespace v1
{
/// \brief Generates the complete beams from the ids per each step and the parent beam /// \brief Generates the complete beams from the ids per each step and the parent beam
/// ids. /// ids.
class NGRAPH_API GatherTree : public Op class NGRAPH_API GatherTree : public Op {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -35,8 +31,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v1 } // namespace v1
} // namespace op } // namespace op

View File

@ -9,20 +9,18 @@
#include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph {
{ namespace op {
namespace op
{
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
namespace v0 namespace v0 {
{
/// \brief Gaussian Error Linear Unit /// \brief Gaussian Error Linear Unit
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ) /// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) )
class NGRAPH_API Gelu : public ngraph::op::util::FusedOp class NGRAPH_API Gelu : public ngraph::op::util::FusedOp {
{
public: public:
static constexpr NodeTypeInfo type_info{"Gelu", 0}; static constexpr NodeTypeInfo type_info{"Gelu", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override {
return type_info;
}
Gelu(); Gelu();
/// \brief Constructs a Gelu operation. /// \brief Constructs a Gelu operation.
/// ///
@ -34,29 +32,22 @@ namespace ngraph
void pre_validate_and_infer_types() override; void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
}; };
} // namespace v0 } // namespace v0
using v0::Gelu; using v0::Gelu;
NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_SUPPRESS_DEPRECATED_END
/// \brief Specifies the approximation to calculate Gelu /// \brief Specifies the approximation to calculate Gelu
enum class GeluApproximationMode enum class GeluApproximationMode { TANH, ERF };
{
TANH,
ERF
};
NGRAPH_API std::ostream& operator<<(std::ostream& s, const GeluApproximationMode& type); NGRAPH_API std::ostream& operator<<(std::ostream& s, const GeluApproximationMode& type);
namespace v7 namespace v7 {
{
/// \brief Gaussian Error Linear Unit /// \brief Gaussian Error Linear Unit
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ) for "approximation" = "erf" /// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ) for "approximation" = "erf"
/// f(x) = 0.5 * x * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3]) for "approximation" = /// f(x) = 0.5 * x * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3]) for "approximation" =
/// "tanh" /// "tanh"
class NGRAPH_API Gelu : public util::UnaryElementwiseArithmetic class NGRAPH_API Gelu : public util::UnaryElementwiseArithmetic {
{
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
@ -65,19 +56,16 @@ namespace ngraph
/// ///
/// \param data Input tensor /// \param data Input tensor
/// \param mode Approximation mode /// \param mode Approximation mode
Gelu(const Output<Node>& data, Gelu(const Output<Node>& data, GeluApproximationMode mode = GeluApproximationMode::ERF);
GeluApproximationMode mode = GeluApproximationMode::ERF);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs, bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
const HostTensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
std::shared_ptr<Node> std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
GeluApproximationMode get_approximation_mode() const; GeluApproximationMode get_approximation_mode() const;
@ -88,16 +76,13 @@ namespace ngraph
} // namespace op } // namespace op
template <> template <>
class NGRAPH_API AttributeAdapter<op::GeluApproximationMode> class NGRAPH_API AttributeAdapter<op::GeluApproximationMode>
: public EnumAttributeAdapterBase<op::GeluApproximationMode> : public EnumAttributeAdapterBase<op::GeluApproximationMode> {
{
public: public:
AttributeAdapter(op::GeluApproximationMode& value) AttributeAdapter(op::GeluApproximationMode& value) : EnumAttributeAdapterBase<op::GeluApproximationMode>(value) {}
: EnumAttributeAdapterBase<op::GeluApproximationMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::GeluApproximationMode>", static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::GeluApproximationMode>", 0};
0}; const DiscreteTypeInfo& get_type_info() const override {
const DiscreteTypeInfo& get_type_info() const override { return type_info; } return type_info;
}
}; };
} // namespace ngraph } // namespace ngraph

Some files were not shown because too many files have changed in this diff Show More