Changed nGraph code style to Google (#6926)

* Changed clang-format

* Fixed code style for tests

* Fixed build

* Fixed code style
This commit is contained in:
Ilya Churaev 2021-08-13 05:28:28 +03:00 committed by GitHub
parent 273c7188a4
commit 39131968c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1637 changed files with 84233 additions and 112281 deletions

View File

@ -1,54 +1,28 @@
BasedOnStyle: LLVM
BasedOnStyle: Google
IndentWidth: 4
UseTab: Never
ColumnLimit: 120
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -4
AlignConsecutiveDeclarations: false
AlignConsecutiveAssignments: false
AlignTrailingComments: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
AlignConsecutiveMacros: true
AllowAllArgumentsOnNextLine: false
AllowAllConstructorInitializersOnNextLine: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Empty
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: false
BinPackArguments: false
BinPackParameters: false
BreakBeforeBraces: Allman
BreakConstructorInitializersBeforeComma: true
ColumnLimit: 100
IndentCaseLabels: false
IndentWrappedFunctionNames: true
KeepEmptyLinesAtTheStartOfBlocks: false
NamespaceIndentation: All
PointerAlignment: Left
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SortIncludes: false
ReflowComments: true
IncludeCategories:
- Regex: '^".*'
Priority: 3
- Regex: '^<.*'
Priority: 2
SortIncludes: true
CommentPragmas: '^#'
DerivePointerAlignment: false
FixNamespaceComments: true
IndentCaseLabels: false
IndentPPDirectives: AfterHash
ForEachMacros:
- foreach
- FOREACH_CHILD

View File

@ -12,235 +12,218 @@
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
namespace ngraph
{
namespace builder
{
class numpy_autobroadcast_incompatible_shapes : public ngraph::ngraph_error
{
public:
numpy_autobroadcast_incompatible_shapes(const ngraph::Shape& shape1,
const ngraph::Shape& shape2);
namespace ngraph {
namespace builder {
class numpy_autobroadcast_incompatible_shapes : public ngraph::ngraph_error {
public:
numpy_autobroadcast_incompatible_shapes(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
private:
const ngraph::Shape m_shape1;
const ngraph::Shape m_shape2;
private:
const ngraph::Shape m_shape1;
const ngraph::Shape m_shape2;
static std::string error_str(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
};
static std::string error_str(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
};
///
/// \brief Broadcast all values, if necessary, to obtain equal shapes according
/// to NumPy's auto-broadcasting scheme.
///
/// \note There are some shape combinations which the autobroadcast algorithm cannot
/// handle. An exception is thrown when such combinations are provided to this
/// function.
///
/// \param values Vector of output values.
///
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
///
/// \return Vector of broadcasted values.
///
OutputVector numpy_broadcast_outputs(const OutputVector& values);
///
/// \brief Broadcast all values, if necessary, to obtain equal shapes according
/// to NumPy's auto-broadcasting scheme.
///
/// \note There are some shape combinations which the autobroadcast algorithm cannot
/// handle. An exception is thrown when such combinations are provided to this
/// function.
///
/// \param values Vector of output values.
///
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
///
/// \return Vector of broadcasted values.
///
OutputVector numpy_broadcast_outputs(const OutputVector& values);
///
/// \brief Broadcast input value to provided shape using NumPy's auto-broadcasting
/// rules.
///
/// \param value Input value
/// \param shape Requested output shape
///
/// \return Node producing values with requested shape.
///
std::shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape);
///
/// \brief Broadcast input value to provided shape using NumPy's auto-broadcasting
/// rules.
///
/// \param value Input value
/// \param shape Requested output shape
///
/// \return Node producing values with requested shape.
///
std::shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape);
/// \brief Wrap two graph values, if necessary, to obtain values with identical shapes,
/// using NumPy's auto-broadcast rules.
///
/// The elements in the std::pair returned by this function correspond to those supplied
/// in the std::pair provided via \p args.
///
/// If \p args.first and \p args.second produce identical shapes, then the returned
/// std::pair will have the same value as \p args.
///
/// If \p args.first and \p args.second produce different shapes, then this function creates
/// new ngraph::op::Reshape and/or ngraph::op::Broadcast nodes, as needed, to wrap
/// \p args.first and/or \p args.second in a manner that yields values with the same shape.
///
/// There are some shape combinations which the autobroadcast algorithm cannot handle.
/// An exception is thrown when such combinations are provided to this function.
///
/// \pre
/// - \p args.first is not null
/// - \p args.second is not null
///
/// \post
/// - The ngraph::Node objects pointed to by \p args.first and \p args.second have not been
/// altered by this function, except by possibly having added consumers of their values.
///
/// - If an exception was not thrown, then the return value's \p first and \p second
/// elements point to ngraph::Node objects whose output values have the same shape.
///
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args);
/// \brief Wrap two graph values, if necessary, to obtain values with identical shapes,
/// using NumPy's auto-broadcast rules.
///
/// The elements in the std::pair returned by this function correspond to those supplied
/// in the std::pair provided via \p args.
///
/// If \p args.first and \p args.second produce identical shapes, then the returned
/// std::pair will have the same value as \p args.
///
/// If \p args.first and \p args.second produce different shapes, then this function creates
/// new ngraph::op::Reshape and/or ngraph::op::Broadcast nodes, as needed, to wrap
/// \p args.first and/or \p args.second in a manner that yields values with the same shape.
///
/// There are some shape combinations which the autobroadcast algorithm cannot handle.
/// An exception is thrown when such combinations are provided to this function.
///
/// \pre
/// - \p args.first is not null
/// - \p args.second is not null
///
/// \post
/// - The ngraph::Node objects pointed to by \p args.first and \p args.second have not been
/// altered by this function, except by possibly having added consumers of their values.
///
/// - If an exception was not thrown, then the return value's \p first and \p second
/// elements point to ngraph::Node objects whose output values have the same shape.
///
/// \exception ngraph::builder::numpy_autobroadcast_incompatible_shapes
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> numpy_broadcast(
const std::pair<Output<Node>, Output<Node>>& args);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul`
/// operation.
/// (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html)
/// This means that only \"stack of matrices\" axes are bidirectionally
/// broadcasted. The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix
/// multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix
/// multiplication.
///
/// \return The vector containing both outputs broadcasted.
///
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul`
/// operation.
/// (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html)
/// This means that only \"stack of matrices\" axes are bidirectionally
/// broadcasted. The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix
/// multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix
/// multiplication.
///
/// \return The vector containing both outputs broadcasted.
///
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, const Output<Node>& right);
/// \brief Cast shape of all input nodes for an element-wise operation that requires
/// shape-compatibility
///
/// \param inputs Original list of inputs
/// \param axis Index starting to align
///
/// \return pdpd-style broadcasted list of nodes.
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis);
/// \brief Cast shape of all input nodes for an element-wise operation that requires
/// shape-compatibility
///
/// \param inputs Original list of inputs
/// \param axis Index starting to align
///
/// \return pdpd-style broadcasted list of nodes.
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// Function calculate which of the output axes are added in this way.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
/// \param start_match_axis The axis along which we want to replicate elements.
/// The starting axis position (0-based) in the output
/// shape from which the current shape of the tensor
/// matches the desired new shape.
///
/// \return The indices of added axes.
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// Function calculate which of the output axes are added in this way.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
/// \param start_match_axis The axis along which we want to replicate elements.
/// The starting axis position (0-based) in the output
/// shape from which the current shape of the tensor
/// matches the desired new shape.
///
/// \return The indices of added axes.
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
///
/// \brief Calculate the output shape of numpy-style broadcast operation for all input
/// shapes.
///
/// This function finds the maximum tensor shape that will be the result of
/// element-wise operation that will be applied to the input shapes vector.
/// The function also prepares the shape of each input for the element-wise
/// operation by left-padding those shapes so that their rank is equal to the
/// left_shape's rank.
///
/// \param input_shapes A vector of input shapes for which a common shape should be
/// found
///
/// \return A pair that contains the target shape as its first object and a vector of
/// padded input shapes ready to be broadcasted as the second object
///
std::pair<Shape, std::vector<Shape>>
get_numpy_broadcast_shapes(const std::vector<Shape>& input_shapes);
///
/// \brief Calculate the output shape of numpy-style broadcast operation for all input
/// shapes.
///
/// This function finds the maximum tensor shape that will be the result of
/// element-wise operation that will be applied to the input shapes vector.
/// The function also prepares the shape of each input for the element-wise
/// operation by left-padding those shapes so that their rank is equal to the
/// left_shape's rank.
///
/// \param input_shapes A vector of input shapes for which a common shape should be
/// found
///
/// \return A pair that contains the target shape as its first object and a vector of
/// padded input shapes ready to be broadcasted as the second object
///
std::pair<Shape, std::vector<Shape>> get_numpy_broadcast_shapes(const std::vector<Shape>& input_shapes);
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
const Shape& new_shape,
std::size_t start_match_axis)
{
auto shape_const =
op::Constant::create(element::u64, Shape{new_shape.size()}, new_shape);
return std::make_shared<op::v1::Broadcast>(
value,
shape_const,
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
}
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
const Shape& new_shape,
std::size_t start_match_axis) {
auto shape_const = op::Constant::create(element::u64, Shape{new_shape.size()}, new_shape);
return std::make_shared<op::v1::Broadcast>(
value,
shape_const,
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
}
namespace opset1
{
///
/// \brief Broadcast right node to left node's shape using legacy scheme.
///
/// \param[in] left The left hand side node of binary operation.
/// \param[in] right The right hand side node of binary operation. The one
/// to be broadcasted.
/// \param[in] start_match_axis The axis index starting mutually equal shapes
/// of both nodes.
///
/// \return The Output object connected to node producing broadcasted right node.
///
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
namespace opset1 {
///
/// \brief Broadcast right node to left node's shape using legacy scheme.
///
/// \param[in] left The left hand side node of binary operation.
/// \param[in] right The right hand side node of binary operation. The one
/// to be broadcasted.
/// \param[in] start_match_axis The axis index starting mutually equal shapes
/// of both nodes.
///
/// \return The Output object connected to node producing broadcasted right node.
///
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
///
/// \brief Reconstructs axes mapping vector for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The vector with the axes indices mapping.
///
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape,
const AxisSet& broadcast_axes);
///
/// \brief Reconstructs axes mapping vector for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The vector with the axes indices mapping.
///
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape, const AxisSet& broadcast_axes);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical as the output shape.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical as the output shape.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape, const Shape& input_shape, std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast operation.
/// \note Shapes' ranks need to be static.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical to consecutive subset of output shape
/// dimensions.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const PartialShape& input_shape,
std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast operation.
/// \note Shapes' ranks need to be static.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical to consecutive subset of output shape
/// dimensions.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const PartialShape& input_shape,
std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The Output object with Node returning axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const AxisSet& broadcast_axes);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The Output object with Node returning axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape, const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
std::size_t start_match_axis);
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, std::size_t start_match_axis);
} // namespace opset1
} // namespace builder
} // namespace ngraph
} // namespace opset1
} // namespace builder
} // namespace ngraph

View File

@ -10,116 +10,117 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/type/float16.hpp"
namespace ngraph
{
namespace builder
{
template <class T>
std::shared_ptr<Node>
make_constant(const element::Type& type, const Shape& shape, const T& num)
{
std::shared_ptr<Node> val = nullptr;
namespace ngraph {
namespace builder {
template <class T>
std::shared_ptr<Node> make_constant(const element::Type& type, const Shape& shape, const T& num) {
std::shared_ptr<Node> val = nullptr;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
# pragma GCC diagnostic push
# pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (type)
{
case element::Type_t::f32:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
break;
case element::Type_t::f64:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)});
break;
case element::Type_t::f16:
val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))});
break;
case element::Type_t::bf16:
val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
break;
case element::Type_t::i64:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)});
break;
case element::Type_t::i32:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)});
break;
case element::Type_t::i16:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)});
break;
case element::Type_t::i8:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)});
break;
case element::Type_t::u64:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)});
break;
case element::Type_t::u32:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)});
break;
case element::Type_t::u16:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)});
break;
case element::Type_t::u8:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)});
break;
case element::Type_t::dynamic:
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
case element::Type_t::boolean:
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::u1:
throw ngraph_error("make_constant: Unsupported element type 'u1'");
case element::Type_t::i4:
throw ngraph_error("make_constant: Unsupported element type 'i4'");
case element::Type_t::u4:
throw ngraph_error("make_constant: Unsupported element type 'u4'");
case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
}
switch (type) {
case element::Type_t::f32:
val =
std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
break;
case element::Type_t::f64:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<double>{static_cast<double>(num)});
break;
case element::Type_t::f16:
val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))});
break;
case element::Type_t::bf16:
val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
break;
case element::Type_t::i64:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<int64_t>{static_cast<int64_t>(num)});
break;
case element::Type_t::i32:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<int32_t>{static_cast<int32_t>(num)});
break;
case element::Type_t::i16:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<int16_t>{static_cast<int16_t>(num)});
break;
case element::Type_t::i8:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<int8_t>{static_cast<int8_t>(num)});
break;
case element::Type_t::u64:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<uint64_t>{static_cast<uint64_t>(num)});
break;
case element::Type_t::u32:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<uint32_t>{static_cast<uint32_t>(num)});
break;
case element::Type_t::u16:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<uint16_t>{static_cast<uint16_t>(num)});
break;
case element::Type_t::u8:
val = std::make_shared<ngraph::op::Constant>(type,
ngraph::Shape{},
std::vector<uint8_t>{static_cast<uint8_t>(num)});
break;
case element::Type_t::dynamic:
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
case element::Type_t::boolean:
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::u1:
throw ngraph_error("make_constant: Unsupported element type 'u1'");
case element::Type_t::i4:
throw ngraph_error("make_constant: Unsupported element type 'i4'");
case element::Type_t::u4:
throw ngraph_error("make_constant: Unsupported element type 'u4'");
case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
# pragma GCC diagnostic pop
#endif
if (shape.size() > 0)
{
ngraph::AxisSet axes;
for (size_t i = 0; i < shape.size(); i++)
{
axes.insert(i);
}
val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
}
return val->add_provenance_group_members_above({});
if (shape.size() > 0) {
ngraph::AxisSet axes;
for (size_t i = 0; i < shape.size(); i++) {
axes.insert(i);
}
val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
}
/// \brief Create constant filled with double value
///
/// \note If num value exceeds the capacity of the type, the value is clamped.
///
/// \param[in] type The type of produced Constant node.
/// \param[in] shape The shape of produced Constant node.
/// \param[in] num The value used to fill Constant node.
///
/// \return The Constant node which has the expected type, shape and value.
///
std::shared_ptr<Node>
make_constant_from_double(const element::Type& type, const Shape& shape, double num);
} // namespace builder
} // namespace ngraph
return val->add_provenance_group_members_above({});
}
/// \brief Create constant filled with double value
///
/// \note If num value exceeds the capacity of the type, the value is clamped.
///
/// \param[in] type The type of produced Constant node.
/// \param[in] shape The shape of produced Constant node.
/// \param[in] num The value used to fill Constant node.
///
/// \return The Constant node which has the expected type, shape and value.
///
std::shared_ptr<Node> make_constant_from_double(const element::Type& type, const Shape& shape, double num);
} // namespace builder
} // namespace ngraph

View File

@ -10,86 +10,80 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace builder
{
/// \brief Specifies method of bias application to avoid numerical problems
enum class BiasMode
{
// Add bias to intermediate result
ADD,
// Calculate max of intermediate result and bias
MAX
};
namespace ngraph {
namespace builder {
/// \brief Specifies method of bias application to avoid numerical problems
enum class BiasMode {
// Add bias to intermediate result
ADD,
// Calculate max of intermediate result and bias
MAX
};
namespace opset1
{
/// \brief Calculates L-0 norm of input tensor.
///
/// \note The L-0 norm represents the cardinality of elements different
/// from zero. This actually is not a "true" norm.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l0_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims = false);
namespace opset1 {
/// \brief Calculates L-0 norm of input tensor.
///
/// \note The L-0 norm represents the cardinality of elements different
/// from zero. This actually is not a "true" norm.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l0_norm(const Output<Node>& value, const Output<Node>& reduction_axes, bool keep_dims = false);
/// \brief Calculates L-1 norm of a value.
///
/// \note The L-1 norm represents the sum of absolute values.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-1 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias = 0.f,
bool keep_dims = false);
/// \brief Calculates L-1 norm of a value.
///
/// \note The L-1 norm represents the sum of absolute values.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-1 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias = 0.f,
bool keep_dims = false);
/// \brief Calculates L-2 norm of input tensor.
///
/// \note The L-2 norm represents the square root of sum of squares of each
/// individual element.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias combined with calculated sum.
/// \param[in] bias_mode The method of bias application.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-2 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l2_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias = 0.f,
BiasMode bias_mode = BiasMode::ADD,
bool keep_dims = false);
/// \brief Calculates L-2 norm of input tensor.
///
/// \note The L-2 norm represents the square root of sum of squares of each
/// individual element.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias combined with calculated sum.
/// \param[in] bias_mode The method of bias application.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-2 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l2_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias = 0.f,
BiasMode bias_mode = BiasMode::ADD,
bool keep_dims = false);
/// \brief Creates node which calculates L-p norm on input tensor.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-p norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
std::size_t p_norm = 2,
float bias = 0.f,
bool keep_dims = false);
} // namespace opset1
} // namespace builder
} // namespace ngraph
/// \brief Creates node which calculates L-p norm on input tensor.
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-p norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
std::size_t p_norm = 2,
float bias = 0.f,
bool keep_dims = false);
} // namespace opset1
} // namespace builder
} // namespace ngraph

View File

@ -7,13 +7,10 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace builder
{
namespace opset1
{
// clang-format off
namespace ngraph {
namespace builder {
namespace opset1 {
// clang-format off
/// \brief Sum-based Mean of a Tensor.
///
/// Calculates
@ -35,16 +32,12 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
// clang-format on
std::shared_ptr<Node> mean(const Output<Node>& node,
const AxisSet& reduction_axes,
bool keep_dims = false);
// clang-format on
std::shared_ptr<Node> mean(const Output<Node>& node, const AxisSet& reduction_axes, bool keep_dims = false);
std::shared_ptr<Node> mean(const Output<Node>& node,
const Output<Node>& reduction_axes,
bool keep_dims = false);
std::shared_ptr<Node> mean(const Output<Node>& node, const Output<Node>& reduction_axes, bool keep_dims = false);
// clang-format off
// clang-format off
/// \brief Sum-based Variance of a Tensor.
///
/// If bessel_correct is true, calculates
@ -70,16 +63,16 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
// clang-format on
std::shared_ptr<Node> variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction = false);
// clang-format on
std::shared_ptr<Node> variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction = false);
std::shared_ptr<Node> variance(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims = false,
bool bessel_correction = false);
} // namespace opset1
std::shared_ptr<Node> variance(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims = false,
bool bessel_correction = false);
} // namespace opset1
} // namespace builder
} // namespace ngraph
} // namespace builder
} // namespace ngraph

View File

@ -11,77 +11,70 @@
#include "ngraph/node.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace builder
{
namespace opset1
{
/// \brief Change shape of a value
///
/// \param[in] value The value to be reshaped.
/// \param[in] shape The new shape.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> reshape(const Output<Node>& value, const Shape& shape);
namespace ngraph {
namespace builder {
namespace opset1 {
/// \brief Change shape of a value
///
/// \param[in] value The value to be reshaped.
/// \param[in] shape The new shape.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> reshape(const Output<Node>& value, const Shape& shape);
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param The vlaue whose axes we want to permute.
/// \param axes_order The permutation of axes.
///
/// \return Transpose:v1 op.
std::shared_ptr<Node> reorder_axes(const Output<Node>& value,
std::vector<size_t> axes_order = {});
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param The vlaue whose axes we want to permute.
/// \param axes_order The permutation of axes.
///
/// \return Transpose:v1 op.
std::shared_ptr<Node> reorder_axes(const Output<Node>& value, std::vector<size_t> axes_order = {});
/// \brief Return transposed value (with axes in reversed order).
///
/// \param Value to transpose.
///
/// \return Transpose:v1 op.
std::shared_ptr<Node> transpose(const Output<Node>& value);
/// \brief Return transposed value (with axes in reversed order).
///
/// \param Value to transpose.
///
/// \return Transpose:v1 op.
std::shared_ptr<Node> transpose(const Output<Node>& value);
/// \brief Flatten a value into a 2D matrix, with a static dividing axis.
///
/// \param The tensor to be flattened.
/// \param The axis dividing shape.
///
/// \return The new value will be a 2D matrix representing the flattened input
/// node.
std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
/// \brief Flatten a value into a 2D matrix, with a static dividing axis.
///
/// \param The tensor to be flattened.
/// \param The axis dividing shape.
///
/// \return The new value will be a 2D matrix representing the flattened input
/// node.
std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] value The value to be expanded.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] value The value to be expanded.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
/// \brief Remove empty axes from input tensor.
///
/// \param[in] value The value to be squeezed.
/// \param[in] axes The vector defining indexes of axes to be removed.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::size_t> axes = {0});
/// \brief Remove empty axes from input tensor.
///
/// \param[in] value The value to be squeezed.
/// \param[in] axes The vector defining indexes of axes to be removed.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value, std::vector<std::size_t> axes = {0});
/// \brief Collapse specified axes into single one.
///
/// \note Collapsed axes create a continuous range starting from outermost axis.
///
/// \param[in] value The value to be reshaped.
/// \param[in] start_axis The start axis index.
/// \param[in] end_axis The end axis (inclusive) index.
///
/// \return The node with collapsed specified axes.
///
std::shared_ptr<Node> collapse(const Output<Node>& value,
const std::size_t start_axis,
const std::size_t end_axis);
} // namespace opset1
} // namespace builder
} // namespace ngraph
/// \brief Collapse specified axes into single one.
///
/// \note Collapsed axes create a continuous range starting from outermost axis.
///
/// \param[in] value The value to be reshaped.
/// \param[in] start_axis The start axis index.
/// \param[in] end_axis The end axis (inclusive) index.
///
/// \return The node with collapsed specified axes.
///
std::shared_ptr<Node> collapse(const Output<Node>& value, const std::size_t start_axis, const std::size_t end_axis);
} // namespace opset1
} // namespace builder
} // namespace ngraph

View File

@ -7,76 +7,69 @@
#include "ngraph/node.hpp"
namespace ngraph
{
namespace builder
{
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to be split.
/// \param length_parts The vector defining the lengths of each split part.
/// \param axis The axis we split input node on. Default value is zero axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
int64_t axis = 0);
namespace ngraph {
namespace builder {
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to be split.
/// \param length_parts The vector defining the lengths of each split part.
/// \param axis The axis we split input node on. Default value is zero axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value, const std::vector<size_t>& length_parts, int64_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param value The value to split.
/// \param split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param axis The axis we split input node on. Default value is zero axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple outputs we split input node into.
///
NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param value The value to split.
/// \param split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param axis The axis we split input node on. Default value is zero axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple outputs we split input node into.
///
NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
namespace opset1
{
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to be split.
/// \param split_lengths The vector defining the lengths of each split part.
/// \param axis The axis we split input node on. Default value is zero
/// axis.
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple outputs we split input node into.
/// The vector is output of Split:v1 op
///
OutputVector split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis = 0);
namespace opset1 {
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to be split.
/// \param split_lengths The vector defining the lengths of each split part.
/// \param axis The axis we split input node on. Default value is zero
/// axis.
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple outputs we split input node into.
/// The vector is output of Split:v1 op
///
OutputVector split(const Output<Node>& value, const std::vector<size_t>& split_lengths, int64_t axis = 0);
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to split.
/// \param num_splits The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param axis The axis we split input node on. Default value is zero
/// axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
/// The vector is output of VariadicSplit:v1 op
///
OutputVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
} // namespace opset1
} // namespace builder
} // namespace ngraph
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to split.
/// \param num_splits The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param axis The axis we split input node on. Default value is zero
/// axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
/// The vector is output of VariadicSplit:v1 op
///
OutputVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
} // namespace opset1
} // namespace builder
} // namespace ngraph

View File

@ -18,470 +18,381 @@
using namespace std;
namespace ngraph
{
namespace builder
{
numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(
const Shape& shape1, const Shape& shape2)
: ngraph_error(error_str(shape1, shape2))
, m_shape1(shape1)
, m_shape2(shape2)
{
namespace ngraph {
namespace builder {
numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(const Shape& shape1,
const Shape& shape2)
: ngraph_error(error_str(shape1, shape2)),
m_shape1(shape1),
m_shape2(shape2) {}
string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1, const Shape& shape2) {
ostringstream os;
os << "Auto-broadcast not possible for these input shapes:"
<< " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
return os.str();
}
///
/// \brief Calculate the output shape of numpy-style broadcast operation for two
/// shapes.
///
/// \note More info:
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
/// Example: left: [3, 1, 10] right: [5, 1] return: [3, 5, 10]
///
/// \param lhs_shape First input shape.
/// \param rhs_shape Second input Shape.
///
/// \return Broadcast shape of input shapes.
///
static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape) {
Shape result;
auto lhs_rank = lhs_shape.size();
auto rhs_rank = rhs_shape.size();
auto max_rank = max(lhs_rank, rhs_rank);
// left-pad the lhs_shape with ones
lhs_shape.insert(begin(lhs_shape), max_rank - lhs_rank, 1);
// left-pad the rhs_shape with ones
rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
for (size_t index = 0; index < max_rank; ++index) {
size_t lhs_dim = lhs_shape.at(index);
size_t rhs_dim = rhs_shape.at(index);
if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) {
throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
}
string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1,
const Shape& shape2)
{
ostringstream os;
os << "Auto-broadcast not possible for these input shapes:"
<< " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
return os.str();
result.push_back(max(lhs_dim, rhs_dim));
}
return result;
};
pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes) {
Shape target_shape = accumulate(begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
vector<Shape> full_shapes;
for (const Shape& input : input_shapes) {
Shape padded_shape{input};
padded_shape.insert(begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
full_shapes.push_back(move(padded_shape));
}
return {target_shape, full_shapes};
}
static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values) {
vector<Shape> input_shapes;
for (const auto& input : values) {
input_shapes.push_back(input.get_shape());
}
return get_numpy_broadcast_shapes(input_shapes);
}
/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of input node. However
/// it should be a superset of it (containing it as a continuous subset). This
/// implies we may expand the number of axes of input node. The ranks of
/// source_shape and output_shape must be equal. This means that the
/// source_shape has to be padded with ones for this operation.
///
/// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast input node.
///
/// \return The broadcasted Node.
///
static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
const Shape& output_shape,
const Shape& source_shape) {
shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
// If node already has the required shape, return original node
if (output_shape == value.get_shape()) {
return broadcasted_node;
}
NGRAPH_CHECK(source_shape.size() == output_shape.size(),
"Ranks of source_shape and output_shape dont match: ",
source_shape.size(),
" vs ",
output_shape.size());
AxisVector broadcast_axes;
Shape squeezed_shape;
// Positions of axes which have length of 1 are needed to calculate broadcast_axes
// for nGraph broadcast operation. We need to remove ones from source shape
// to avoid broadcasting axis conflict.
for (size_t index = 0; index < output_shape.size(); ++index) {
if (source_shape.at(index) == 1 && output_shape.at(index) != 1) {
broadcast_axes.push_back(index);
} else {
squeezed_shape.push_back(source_shape.at(index));
}
}
///
/// \brief Calculate the output shape of numpy-style broadcast operation for two
/// shapes.
///
/// \note More info:
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
/// Example: left: [3, 1, 10] right: [5, 1] return: [3, 5, 10]
///
/// \param lhs_shape First input shape.
/// \param rhs_shape Second input Shape.
///
/// \return Broadcast shape of input shapes.
///
static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape)
{
Shape result;
auto lhs_rank = lhs_shape.size();
auto rhs_rank = rhs_shape.size();
auto max_rank = max(lhs_rank, rhs_rank);
if (squeezed_shape != value.get_shape()) {
broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
}
// left-pad the lhs_shape with ones
lhs_shape.insert(begin(lhs_shape), max_rank - lhs_rank, 1);
// left-pad the rhs_shape with ones
rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
if (!broadcast_axes.empty()) {
auto shape_const = op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
broadcasted_node =
make_shared<op::v1::Broadcast>(broadcasted_node,
shape_const,
opset1::get_axes_mapping_output(output_shape, broadcast_axes));
}
for (size_t index = 0; index < max_rank; ++index)
{
size_t lhs_dim = lhs_shape.at(index);
size_t rhs_dim = rhs_shape.at(index);
return broadcasted_node;
}
if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1)
{
throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
}
/// \brief Broadcast input node.
///
/// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] axis The start index to align with output_shape
///
/// \return The broadcasted Node.
///
static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value, const Shape& output_shape, int64_t axis) {
auto value_shape = value.get_shape();
result.push_back(max(lhs_dim, rhs_dim));
}
// If node already has the required shape, return original node
if (output_shape == value_shape) {
return value.get_node_shared_ptr();
}
return result;
};
if (axis == -1) {
axis = output_shape.size() - value_shape.size();
}
pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes)
{
Shape target_shape = accumulate(
begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
auto trimmed_value_shape = value_shape;
while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1) {
trimmed_value_shape.pop_back();
}
vector<Shape> full_shapes;
for (const Shape& input : input_shapes)
{
Shape padded_shape{input};
padded_shape.insert(
begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
full_shapes.push_back(move(padded_shape));
}
AxisSet axes;
for (int64_t i = 0; i < axis; ++i) {
axes.insert(static_cast<size_t>(i));
}
return {target_shape, full_shapes};
for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i) {
axes.insert(i);
}
auto trimmed_value = value;
if (value_shape != trimmed_value_shape) {
trimmed_value = builder::opset1::reshape(value, trimmed_value_shape);
}
auto shape_const = op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
auto value_bcast =
make_shared<op::v1::Broadcast>(trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));
return move(value_bcast);
}
pair<shared_ptr<Node>, shared_ptr<Node>> numpy_broadcast(const pair<Output<Node>, Output<Node>>& args) {
NGRAPH_CHECK(args.first.get_node());
NGRAPH_CHECK(args.second.get_node());
const Shape& arg1_in_shape = args.first.get_shape();
const Shape& arg2_in_shape = args.second.get_shape();
// Handle the trivial case...
if (arg1_in_shape == arg2_in_shape) {
return make_pair(args.first.get_node_shared_ptr(), args.second.get_node_shared_ptr());
}
NodeVector bcasted_outputs = as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
}
OutputVector numpy_broadcast_outputs(const OutputVector& values) {
if (values.size() <= 1) {
return values;
}
// find the output tensor's shape, then broadcast all inputs so that they are compatible
auto bcast_shapes = get_numpy_broadcast_shapes(values);
OutputVector broadcasted_inputs;
for (size_t i = 0; i < values.size(); ++i) {
broadcasted_inputs.push_back(numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
}
return broadcasted_inputs;
}
shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape) {
auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
}
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left, const Output<Node>& right) {
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
// Broadcast only _stack of matrices_ axes.
const auto& numpy_shapes = get_numpy_broadcast_shapes(
{Shape{begin(left_shape), next(end(left_shape), -2)}, Shape{begin(right_shape), next(end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
auto left_output_shape = numpy_shapes.first;
auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions.
left_output_shape.insert(end(left_output_shape), next(begin(left_shape), left_shape.size() - 2), end(left_shape));
right_output_shape.insert(end(right_output_shape),
next(begin(right_shape), right_shape.size() - 2),
end(right_shape));
auto left_full_shape = numpy_shapes.second.at(0);
auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions.
left_full_shape.insert(end(left_full_shape), next(begin(left_shape), left_shape.size() - 2), end(left_shape));
right_full_shape.insert(end(right_full_shape), next(begin(right_shape), right_shape.size() - 2), end(right_shape));
return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
numpy_broadcast_node(right, right_output_shape, right_full_shape)};
}
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis) {
if (inputs.size() <= 1) {
return inputs;
}
OutputVector broadcasted_inputs{inputs[0]};
for (size_t i = 1; i < inputs.size(); ++i) {
broadcasted_inputs.push_back(broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
}
return broadcasted_inputs;
}
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
size_t start_match_axis) {
vector<size_t> axes(output_shape.size() - input_shape.size());
// Populate the axes vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range:
// [start_match_axis, start_match_axis + input_shape.size()]
iota(begin(axes), begin(axes) + start_match_axis, 0);
iota(begin(axes) + start_match_axis, end(axes), start_match_axis + input_shape.size());
auto axes_mapping = opset1::get_axes_mapping(output_shape, axes);
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}
namespace opset1 {
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis) {
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical) {
return right;
}
// Prepare new shape of right operand for broadcasting
// Remove dimensions with length=1 from back
auto new_right_shape = right_shape;
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension) {
if (new_right_shape.at(dimension) == 1) {
new_right_shape.pop_back();
} else {
break;
}
}
static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values)
{
vector<Shape> input_shapes;
for (const auto& input : values)
{
input_shapes.push_back(input.get_shape());
}
return get_numpy_broadcast_shapes(input_shapes);
// Find first dimensions at front with length different from 1
size_t num_ones = 0;
for (size_t dimension : new_right_shape) {
if (dimension == 1) {
++num_ones;
} else {
break;
}
}
/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of input node. However
/// it should be a superset of it (containing it as a continuous subset). This
/// implies we may expand the number of axes of input node. The ranks of
/// source_shape and output_shape must be equal. This means that the
/// source_shape has to be padded with ones for this operation.
///
/// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast input node.
///
/// \return The broadcasted Node.
///
static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
const Shape& output_shape,
const Shape& source_shape)
{
shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
// If node already has the required shape, return original node
if (output_shape == value.get_shape())
{
return broadcasted_node;
}
// Remove dimensions with length=1 from front
new_right_shape.erase(begin(new_right_shape), next(begin(new_right_shape), num_ones));
NGRAPH_CHECK(source_shape.size() == output_shape.size(),
"Ranks of source_shape and output_shape dont match: ",
source_shape.size(),
" vs ",
output_shape.size());
auto reshape_right = reshape(right, new_right_shape);
AxisVector broadcast_axes;
Shape squeezed_shape;
// Positions of axes which have length of 1 are needed to calculate broadcast_axes
// for nGraph broadcast operation. We need to remove ones from source shape
// to avoid broadcasting axis conflict.
for (size_t index = 0; index < output_shape.size(); ++index)
{
if (source_shape.at(index) == 1 && output_shape.at(index) != 1)
{
broadcast_axes.push_back(index);
}
else
{
squeezed_shape.push_back(source_shape.at(index));
}
}
// Move broadcast start axis parameter to right
start_match_axis += num_ones;
if (squeezed_shape != value.get_shape())
{
broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
}
return make_broadcast(reshape_right, left_shape, start_match_axis);
}
if (!broadcast_axes.empty())
{
auto shape_const =
op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
broadcasted_node = make_shared<op::v1::Broadcast>(
broadcasted_node,
shape_const,
opset1::get_axes_mapping_output(output_shape, broadcast_axes));
}
vector<size_t> get_axes_mapping(const Shape& output_shape, const AxisSet& broadcast_axes) {
NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
vector<size_t> axes_mapping(output_shape.size());
iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i) {
axes_mapping.erase(axes_mapping.begin() + *i);
}
return axes_mapping;
}
return broadcasted_node;
}
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const PartialShape& input_shape,
std::size_t start_match_axis) {
NGRAPH_CHECK((input_shape.rank().is_static() && output_shape.rank().is_static()),
"Tensor's rank has to be static.");
NGRAPH_CHECK(
(input_shape.rank().get_length() + static_cast<int64_t>(start_match_axis) <= output_shape.rank().get_length()),
"Unable to figure out axes mapping.");
/// \brief Broadcast input node.
///
/// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] axis The start index to align with output_shape
///
/// \return The broadcasted Node.
///
static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value,
const Shape& output_shape,
int64_t axis)
{
auto value_shape = value.get_shape();
vector<int64_t> mapping(input_shape.rank().get_length());
iota(begin(mapping), end(mapping), start_match_axis);
// If node already has the required shape, return original node
if (output_shape == value_shape)
{
return value.get_node_shared_ptr();
}
return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
}
if (axis == -1)
{
axis = output_shape.size() - value_shape.size();
}
Output<Node> get_axes_mapping_output(const Shape& output_shape, const AxisSet& broadcast_axes) {
vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}
auto trimmed_value_shape = value_shape;
while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1)
{
trimmed_value_shape.pop_back();
}
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const Output<Node>& input_shape,
std::size_t start_match_axis) {
const auto one_node = opset7::Constant::create(element::i64, Shape{}, {1});
const auto zero_node = opset7::Constant::create(element::i64, Shape{}, {0});
const auto start_match_axis_node = opset7::Constant::create(element::i64, Shape{}, {start_match_axis});
const auto target_shape_rank_node =
builder::opset1::reshape(std::make_shared<opset7::ShapeOf>(input_shape), Shape{});
AxisSet axes;
for (int64_t i = 0; i < axis; ++i)
{
axes.insert(static_cast<size_t>(i));
}
const auto range_node = std::make_shared<opset7::Range>(zero_node, target_shape_rank_node, one_node, element::i64);
for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i)
{
axes.insert(i);
}
// workaround for GPU plugin type incompatibility
const auto range_node_converted =
std::make_shared<opset7::Convert>(range_node, start_match_axis_node->get_element_type());
// end of workaround
auto trimmed_value = value;
if (value_shape != trimmed_value_shape)
{
trimmed_value = builder::opset1::reshape(value, trimmed_value_shape);
}
const auto result = std::make_shared<opset7::Add>(range_node_converted, start_match_axis_node);
return result;
}
auto shape_const =
op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
auto value_bcast = make_shared<op::v1::Broadcast>(
trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));
/// \brief Broadcast `node` to `target_shape`, marking `broadcast_axes` as the
///        newly introduced axes (explicit axes-mapping Broadcast).
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, const AxisSet& broadcast_axes) {
    // Target shape and axes mapping are both passed to Broadcast as i64 constants.
    const auto shape_const = op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape);
    const auto axes_mapping = get_axes_mapping_output(target_shape, broadcast_axes);
    return make_shared<op::v1::Broadcast>(node, shape_const, axes_mapping);
}
return move(value_bcast);
}
/// \brief Broadcast `node` to `target_shape`, aligning the node's axes to the
///        output starting at `start_match_axis`.
Output<Node> make_broadcast(const Output<Node>& node, const Shape& target_shape, size_t start_match_axis) {
    // The node's (possibly dynamic) shape is queried at graph-build time via ShapeOf
    // so the axes mapping can be computed in-graph.
    const auto node_shape = std::make_shared<opset7::ShapeOf>(node);
    const auto shape_const = op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape);
    const auto axes_mapping = get_axes_mapping_output(target_shape, node_shape, start_match_axis);
    return make_shared<op::v1::Broadcast>(node, shape_const, axes_mapping);
}
/// \brief Numpy-style broadcast a pair of values to a mutually compatible shape.
///
/// \param args  The two values to broadcast against each other; both must hold
///              non-null nodes with static shapes.
///
/// \return The (possibly broadcast) pair of nodes, in the same order.
pair<shared_ptr<Node>, shared_ptr<Node>> numpy_broadcast(const pair<Output<Node>, Output<Node>>& args) {
    NGRAPH_CHECK(args.first.get_node());
    NGRAPH_CHECK(args.second.get_node());

    // Trivial case: shapes already agree, so no broadcast nodes are inserted.
    if (args.first.get_shape() == args.second.get_shape()) {
        return {args.first.get_node_shared_ptr(), args.second.get_node_shared_ptr()};
    }

    const auto broadcasted = as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
    return {broadcasted.at(0), broadcasted.at(1)};
}
/// \brief Numpy-style broadcast every value in `values` to a common shape.
///
/// \param values  Inputs to make mutually compatible.
///
/// \return One output per input, each broadcast to the common target shape.
OutputVector numpy_broadcast_outputs(const OutputVector& values) {
    // Zero or one input: nothing to broadcast against.
    if (values.size() <= 1) {
        return values;
    }

    // First element: the common output shape; second: per-input aligned shapes.
    const auto bcast_shapes = get_numpy_broadcast_shapes(values);

    OutputVector result;
    result.reserve(values.size());
    for (size_t idx = 0; idx < values.size(); ++idx) {
        result.push_back(numpy_broadcast_node(values[idx], bcast_shapes.first, bcast_shapes.second[idx]));
    }
    return result;
}
/// \brief Numpy-style broadcast a single value against a target shape.
shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape) {
    const auto shapes = get_numpy_broadcast_shapes({value.get_shape(), shape});
    // Index 0 of the aligned-shape list corresponds to `value`.
    return numpy_broadcast_node(value, shapes.first, shapes.second[0]);
}
/// \brief Broadcast the "stack of matrices" (batch) axes of two MatMul operands
///        numpy-style, while leaving each operand's last two (matrix) axes intact.
///
/// \param left   Left MatMul operand.
/// \param right  Right MatMul operand.
///
/// \return {broadcast left, broadcast right} with compatible batch dimensions.
///
/// NOTE(review): the `next(end(shape), -2)` iterator arithmetic assumes both
/// operands have rank >= 2 — TODO confirm callers guarantee this.
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right)
{
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
// Broadcast only _stack of matrices_ axes: drop the trailing two dims before
// computing the common numpy shape.
const auto& numpy_shapes =
get_numpy_broadcast_shapes({Shape{begin(left_shape), next(end(left_shape), -2)},
Shape{begin(right_shape), next(end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
auto left_output_shape = numpy_shapes.first;
auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions.
left_output_shape.insert(end(left_output_shape),
next(begin(left_shape), left_shape.size() - 2),
end(left_shape));
right_output_shape.insert(end(right_output_shape),
next(begin(right_shape), right_shape.size() - 2),
end(right_shape));
// Per-operand aligned (padded-with-ones) shapes for the batch axes.
auto left_full_shape = numpy_shapes.second.at(0);
auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions.
left_full_shape.insert(end(left_full_shape),
next(begin(left_shape), left_shape.size() - 2),
end(left_shape));
right_full_shape.insert(end(right_full_shape),
next(begin(right_shape), right_shape.size() - 2),
end(right_shape));
return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
numpy_broadcast_node(right, right_output_shape, right_full_shape)};
}
/// \brief PaddlePaddle-style broadcast: the first input dictates the target
///        shape, every subsequent input is broadcast to it starting at `axis`.
///
/// \param inputs  Values to align; inputs[0] is the shape reference.
/// \param axis    Axis at which the trailing inputs start matching inputs[0].
///
/// \return inputs[0] unchanged, followed by the broadcast remaining inputs.
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis) {
    // Fewer than two inputs: nothing needs aligning.
    if (inputs.size() <= 1) {
        return inputs;
    }

    OutputVector result{inputs[0]};
    const auto& target_shape = inputs[0].get_shape();
    for (size_t idx = 1; idx < inputs.size(); ++idx) {
        result.push_back(broadcast_value_pdpd_style(inputs[idx], target_shape, axis));
    }
    return result;
}
/// \brief Compute the axes-mapping Constant for broadcasting `input_shape` into
///        `output_shape` with the input aligned at `start_match_axis`.
///
/// The broadcast axes are every output dimension OUTSIDE the matched window
/// [start_match_axis, start_match_axis + input_shape.size()).
std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
                                               const Shape& input_shape,
                                               size_t start_match_axis) {
    vector<size_t> axes;
    axes.reserve(output_shape.size() - input_shape.size());
    for (size_t dim = 0; dim < output_shape.size(); ++dim) {
        // Skip the dimensions matched against the input; collect all others.
        if (dim < start_match_axis || dim >= start_match_axis + input_shape.size()) {
            axes.push_back(dim);
        }
    }
    auto axes_mapping = opset1::get_axes_mapping(output_shape, axes);
    return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}
namespace opset1
{
/// \brief Legacy-style (NumPy-with-start-axis) broadcast of `right` so it can be
///        combined with `left` in a binary elementwise operation.
///
/// \param left              Operand whose shape is the broadcast target.
/// \param right             Operand to be reshaped/broadcast.
/// \param start_match_axis  Axis of `left` at which `right`'s dims start matching.
///
/// \return `right` unchanged when shapes already match, otherwise a Broadcast of
///         a squeezed `right`.
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis)
{
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical)
{
return right;
}
// Prepare new shape of right operand for broadcasting
// Remove dimensions with length=1 from back
auto new_right_shape = right_shape;
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
{
if (new_right_shape.at(dimension) == 1)
{
new_right_shape.pop_back();
}
else
{
break;
}
}
// Find first dimensions at front with length different from 1
size_t num_ones = 0;
for (size_t dimension : new_right_shape)
{
if (dimension == 1)
{
++num_ones;
}
else
{
break;
}
}
// Remove dimensions with length=1 from front
new_right_shape.erase(begin(new_right_shape),
next(begin(new_right_shape), num_ones));
// Reshape drops the stripped size-1 dims before broadcasting.
auto reshape_right = reshape(right, new_right_shape);
// Move broadcast start axis parameter to right
// (the leading 1s we removed shift where matching begins).
start_match_axis += num_ones;
return make_broadcast(reshape_right, left_shape, start_match_axis);
}
/// \brief Compute the axes mapping for an explicit Broadcast: the list of
///        output axes (in order) that are NOT in `broadcast_axes`.
///
/// \param output_shape    Shape of the broadcast result.
/// \param broadcast_axes  Output axes introduced by the broadcast.
///
/// \return output_shape.size() - broadcast_axes.size() axis indices.
vector<size_t> get_axes_mapping(const Shape& output_shape,
const AxisSet& broadcast_axes)
{
NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
// Start from the identity mapping 0..rank-1 ...
vector<size_t> axes_mapping(output_shape.size());
iota(axes_mapping.begin(), axes_mapping.end(), 0);
// ... then erase the broadcast axes back-to-front so earlier
// positions are not invalidated by the erasures.
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
{
axes_mapping.erase(axes_mapping.begin() + *i);
}
return axes_mapping;
}
/// \brief Build the axes-mapping Constant for broadcasting a tensor of
///        `input_shape` into `output_shape`, with matching starting at
///        `start_match_axis`: the mapping is the contiguous run
///        [start_match_axis, start_match_axis + input rank).
///
/// Both ranks must be static; the input must fit inside the output.
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const PartialShape& input_shape,
std::size_t start_match_axis)
{
NGRAPH_CHECK((input_shape.rank().is_static() && output_shape.rank().is_static()),
"Tensor's rank has to be static.");
NGRAPH_CHECK(
(input_shape.rank().get_length() + static_cast<int64_t>(start_match_axis) <=
output_shape.rank().get_length()),
"Unable to figure out axes mapping.");
vector<int64_t> mapping(input_shape.rank().get_length());
iota(begin(mapping), end(mapping), start_match_axis);
return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
}
/// \brief Produce the axes-mapping input for an explicit Broadcast as an i64
///        Constant, derived from the broadcast axes of `output_shape`.
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const AxisSet& broadcast_axes)
{
vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}
/// \brief Build the axes mapping in-graph (for dynamic input ranks): produces
///        the sequence start_match_axis, start_match_axis+1, ... of length equal
///        to the runtime rank of `input_shape` via Range + Add subgraph.
///
/// \param output_shape      Unused here; kept for interface symmetry with the
///                          static-shape overload.
/// \param input_shape       Node producing the input's shape (e.g. ShapeOf result).
/// \param start_match_axis  First output axis the input is matched against.
Output<Node> get_axes_mapping_output(const PartialShape& output_shape,
const Output<Node>& input_shape,
std::size_t start_match_axis)
{
const auto one_node = opset7::Constant::create(element::i64, Shape{}, {1});
const auto zero_node = opset7::Constant::create(element::i64, Shape{}, {0});
const auto start_match_axis_node =
opset7::Constant::create(element::i64, Shape{}, {start_match_axis});
// Scalar rank of the input, obtained by reshaping ShapeOf's 1-D output to {}.
const auto target_shape_rank_node = builder::opset1::reshape(
std::make_shared<opset7::ShapeOf>(input_shape), Shape{});
// 0, 1, ..., rank-1
const auto range_node = std::make_shared<opset7::Range>(
zero_node, target_shape_rank_node, one_node, element::i64);
// workaround for GPU plugin type incompatibility
const auto range_node_converted = std::make_shared<opset7::Convert>(
range_node, start_match_axis_node->get_element_type());
// end of workaround
// Shift the range by start_match_axis to get the final mapping.
const auto result =
std::make_shared<opset7::Add>(range_node_converted, start_match_axis_node);
return result;
}
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
const AxisSet& broadcast_axes)
{
return make_shared<op::v1::Broadcast>(
node,
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
get_axes_mapping_output(target_shape, broadcast_axes));
}
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
size_t start_match_axis)
{
const auto node_shape = std::make_shared<opset7::ShapeOf>(node);
return make_shared<op::v1::Broadcast>(
node,
op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
get_axes_mapping_output(target_shape, node_shape, start_match_axis));
}
} // namespace opset1
} // namespace builder
} // namespace ngraph
} // namespace opset1
} // namespace builder
} // namespace ngraph

View File

@ -4,91 +4,68 @@
#include "ngraph/builder/make_constant.hpp"
namespace ngraph
{
namespace builder
{
std::shared_ptr<Node>
make_constant_from_double(const element::Type& type, const Shape& shape, double num)
{
auto ceil_func = [](double x) { return ceil(x); };
namespace ngraph {
namespace builder {
std::shared_ptr<Node> make_constant_from_double(const element::Type& type, const Shape& shape, double num) {
auto ceil_func = [](double x) {
return ceil(x);
};
std::shared_ptr<ngraph::Node> result = nullptr;
switch (type)
{
case element::Type_t::i8:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int8_t>(num, ceil_func));
break;
}
case element::Type_t::i16:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int16_t>(num, ceil_func));
break;
}
case element::Type_t::i32:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int32_t>(num, ceil_func));
break;
}
case element::Type_t::i64:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<int64_t>(num, ceil_func));
break;
}
case element::Type_t::u8:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint8_t>(num, ceil_func));
break;
}
case element::Type_t::u16:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint16_t>(num, ceil_func));
break;
}
case element::Type_t::u32:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint32_t>(num, ceil_func));
break;
}
case element::Type_t::u64:
{
result = std::make_shared<ngraph::op::Constant>(
type, shape, double_to_int<uint64_t>(num, ceil_func));
break;
}
case element::Type_t::f16:
{
result = builder::make_constant(type, shape, static_cast<float16>(num));
break;
}
case element::Type_t::bf16:
{
result = builder::make_constant(type, shape, static_cast<bfloat16>(num));
break;
}
case element::Type_t::f32:
{
result = builder::make_constant(type, shape, static_cast<float>(num));
break;
}
case element::Type_t::f64:
{
result = builder::make_constant(type, shape, num);
break;
}
default:
throw std::runtime_error("Unsupported data type during make_constant_from_double");
break;
}
return result;
}
} // namespace builder
} // namespace ngraph
std::shared_ptr<ngraph::Node> result = nullptr;
switch (type) {
case element::Type_t::i8: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int8_t>(num, ceil_func));
break;
}
case element::Type_t::i16: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int16_t>(num, ceil_func));
break;
}
case element::Type_t::i32: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int32_t>(num, ceil_func));
break;
}
case element::Type_t::i64: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<int64_t>(num, ceil_func));
break;
}
case element::Type_t::u8: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint8_t>(num, ceil_func));
break;
}
case element::Type_t::u16: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint16_t>(num, ceil_func));
break;
}
case element::Type_t::u32: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint32_t>(num, ceil_func));
break;
}
case element::Type_t::u64: {
result = std::make_shared<ngraph::op::Constant>(type, shape, double_to_int<uint64_t>(num, ceil_func));
break;
}
case element::Type_t::f16: {
result = builder::make_constant(type, shape, static_cast<float16>(num));
break;
}
case element::Type_t::bf16: {
result = builder::make_constant(type, shape, static_cast<bfloat16>(num));
break;
}
case element::Type_t::f32: {
result = builder::make_constant(type, shape, static_cast<float>(num));
break;
}
case element::Type_t::f64: {
result = builder::make_constant(type, shape, num);
break;
}
default:
throw std::runtime_error("Unsupported data type during make_constant_from_double");
break;
}
return result;
}
} // namespace builder
} // namespace ngraph

View File

@ -3,6 +3,7 @@
//
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
@ -18,137 +19,111 @@
using namespace std;
namespace ngraph
{
namespace builder
{
namespace detail
{
namespace opset1
{
shared_ptr<Node> lp_norm(const Output<Node>& value,
size_t p_norm,
const Output<Node>& reduction_axes,
float bias,
bool keep_dims)
{
// In general "entrywise" lp-norm for matrix `A` is defined as following double
// sum:
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
shared_ptr<Node> abs_values{make_shared<ngraph::opset1::Abs>(value)};
shared_ptr<Node> p_node = ngraph::opset1::Constant::create(
value.get_element_type(), Shape{}, {p_norm});
namespace ngraph {
namespace builder {
namespace detail {
namespace opset1 {
shared_ptr<Node> lp_norm(const Output<Node>& value,
size_t p_norm,
const Output<Node>& reduction_axes,
float bias,
bool keep_dims) {
// In general "entrywise" lp-norm for matrix `A` is defined as following double
// sum:
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
shared_ptr<Node> abs_values{make_shared<ngraph::opset1::Abs>(value)};
shared_ptr<Node> p_node = ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {p_norm});
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)};
values =
make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)};
values = make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(
values->get_element_type(), Shape{}, {bias})};
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
values = make_shared<ngraph::opset1::Add>(values, bias_node);
values = make_shared<ngraph::opset1::Add>(values, bias_node);
// Get outer part of equation: raise values to 1/p_norm exponent.
shared_ptr<Node> inv_p_node = ngraph::opset1::Constant::create(
values->get_element_type(), Shape{}, {1.f / p_norm});
// Get outer part of equation: raise values to 1/p_norm exponent.
shared_ptr<Node> inv_p_node = ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {1.f / p_norm});
return {make_shared<ngraph::opset1::Power>(values, inv_p_node)
->add_provenance_group_members_above({value})};
}
} // namespace opset1
} // namespace detail
return {make_shared<ngraph::opset1::Power>(values, inv_p_node)->add_provenance_group_members_above({value})};
}
} // namespace opset1
} // namespace detail
shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims)
{
// L0 norm returns number of elements different from zero.
const shared_ptr<Node> zero_node{
ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {0.f})};
shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims) {
// L0 norm returns number of elements different from zero.
const shared_ptr<Node> zero_node{ngraph::opset1::Constant::create(value.get_element_type(), Shape{}, {0.f})};
// Convert bool values to input node data type.
const shared_ptr<Node> non_zero_values = make_shared<ngraph::opset1::Convert>(
make_shared<ngraph::opset1::NotEqual>(value, zero_node), value.get_element_type());
// Convert bool values to input node data type.
const shared_ptr<Node> non_zero_values =
make_shared<ngraph::opset1::Convert>(make_shared<ngraph::opset1::NotEqual>(value, zero_node),
value.get_element_type());
return make_shared<ngraph::opset1::ReduceSum>(
non_zero_values, reduction_axes, keep_dims)
->add_provenance_group_members_above({value});
}
return make_shared<ngraph::opset1::ReduceSum>(non_zero_values, reduction_axes, keep_dims)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias,
bool keep_dims)
{
const shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(
make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias,
bool keep_dims) {
const shared_ptr<Node> values{
make_shared<ngraph::opset1::ReduceSum>(make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
const shared_ptr<Node> bias_node{
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
const shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
return make_shared<ngraph::opset1::Add>(values, bias_node)
->add_provenance_group_members_above({value});
}
return make_shared<ngraph::opset1::Add>(values, bias_node)->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::l2_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias,
BiasMode bias_mode,
bool keep_dims)
{
shared_ptr<Node> pow = make_shared<ngraph::opset1::Power>(
value, make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2));
shared_ptr<Node> values{
make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
shared_ptr<Node> builder::opset1::l2_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias,
BiasMode bias_mode,
bool keep_dims) {
shared_ptr<Node> pow =
make_shared<ngraph::opset1::Power>(value,
make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2));
shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
shared_ptr<Node> bias_node{
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
shared_ptr<Node> result;
switch (bias_mode)
{
case BiasMode::MAX:
{
result = make_shared<ngraph::opset1::Sqrt>(
make_shared<ngraph::opset1::Maximum>(values, bias_node));
break;
}
case BiasMode::ADD:
default:
result = make_shared<ngraph::opset1::Sqrt>(
make_shared<ngraph::opset1::Add>(values, bias_node));
}
return result->add_provenance_group_members_above({value});
}
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
shared_ptr<Node> result;
switch (bias_mode) {
case BiasMode::MAX: {
result = make_shared<ngraph::opset1::Sqrt>(make_shared<ngraph::opset1::Maximum>(values, bias_node));
break;
}
case BiasMode::ADD:
default:
result = make_shared<ngraph::opset1::Sqrt>(make_shared<ngraph::opset1::Add>(values, bias_node));
}
return result->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
size_t p_norm,
float bias,
bool keep_dims)
{
// The number of non-zero elements
if (p_norm == 0)
{
return opset1::l0_norm(value, reduction_axes, keep_dims);
}
// sum of absolute values.
else if (p_norm == 1)
{
return opset1::l1_norm(value, reduction_axes, bias, keep_dims);
}
// sqrt of sum of squares - Euclidean norm
else if (p_norm == 2)
{
return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims);
}
// generic case
else
{
return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims);
}
}
shared_ptr<Node> builder::opset1::lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
size_t p_norm,
float bias,
bool keep_dims) {
// The number of non-zero elements
if (p_norm == 0) {
return opset1::l0_norm(value, reduction_axes, keep_dims);
}
// sum of absolute values.
else if (p_norm == 1) {
return opset1::l1_norm(value, reduction_axes, bias, keep_dims);
}
// sqrt of sum of squares - Euclidean norm
else if (p_norm == 2) {
return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims);
}
// generic case
else {
return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims);
}
}
} // namespace builder
} // namespace builder
} // namespace ngraph
} // namespace ngraph

View File

@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/builder/reduce_ops.hpp"
#include <numeric>
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
@ -16,133 +17,110 @@
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace builder
{
/// \brief Count the elements covered by a reduction: the product of the shape's
///        extents along every axis in `reduction_axes`.
size_t get_num_elements(const Shape& shape, const AxisSet& reduction_axes)
{
size_t N = 1;
for (auto a : reduction_axes)
{
N *= shape[a];
}
return N;
}
namespace ngraph {
namespace builder {
/// \brief Number of elements a reduction touches: the product of the extents of
///        `shape` along each axis in `reduction_axes`.
size_t get_num_elements(const Shape& shape, const AxisSet& reduction_axes) {
    size_t count = 1;
    for (const auto axis : reduction_axes) {
        count *= shape[axis];
    }
    return count;
}
std::shared_ptr<Node> get_num_elements(const Output<Node>& value,
const Output<Node>& reduction_axes)
{
const auto value_shape = std::make_shared<ngraph::opset1::ShapeOf>(value);
const auto dim_values = std::make_shared<ngraph::opset1::Gather>(
value_shape,
reduction_axes,
ngraph::opset1::Constant::create(element::i64, {}, {0}));
std::shared_ptr<Node> get_num_elements(const Output<Node>& value, const Output<Node>& reduction_axes) {
const auto value_shape = std::make_shared<ngraph::opset1::ShapeOf>(value);
const auto dim_values =
std::make_shared<ngraph::opset1::Gather>(value_shape,
reduction_axes,
ngraph::opset1::Constant::create(element::i64, {}, {0}));
return std::make_shared<ngraph::opset1::ReduceProd>(
dim_values, ngraph::opset1::Constant::create(element::i64, {}, {0}));
}
return std::make_shared<ngraph::opset1::ReduceProd>(dim_values,
ngraph::opset1::Constant::create(element::i64, {}, {0}));
}
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
const AxisSet& reduction_axes,
bool keep_dims)
{
std::shared_ptr<Node> elems_number;
const auto value_elem_type = value.get_element_type();
const auto reduction_axes_const = ngraph::opset1::Constant::create(
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector());
const auto value_elems_sum =
std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes_const, keep_dims);
if (value.get_partial_shape().is_static())
{
const auto elems_number_value = get_num_elements(value.get_shape(), reduction_axes);
elems_number = ngraph::opset1::Constant::create(
value_elem_type, Shape{}, {elems_number_value});
}
else
{
elems_number = get_num_elements(value, reduction_axes_const);
elems_number =
std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
}
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value, const AxisSet& reduction_axes, bool keep_dims) {
std::shared_ptr<Node> elems_number;
const auto value_elem_type = value.get_element_type();
const auto reduction_axes_const =
ngraph::opset1::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector());
const auto value_elems_sum = std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes_const, keep_dims);
if (value.get_partial_shape().is_static()) {
const auto elems_number_value = get_num_elements(value.get_shape(), reduction_axes);
elems_number = ngraph::opset1::Constant::create(value_elem_type, Shape{}, {elems_number_value});
} else {
elems_number = get_num_elements(value, reduction_axes_const);
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
}
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
->add_provenance_group_members_above({value});
}
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
->add_provenance_group_members_above({value});
}
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims)
{
std::shared_ptr<Node> elems_number;
const auto value_elem_type = value.get_element_type();
const auto value_elems_sum =
std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes, keep_dims);
elems_number = get_num_elements(value, reduction_axes);
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
std::shared_ptr<Node> builder::opset1::mean(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims) {
std::shared_ptr<Node> elems_number;
const auto value_elem_type = value.get_element_type();
const auto value_elems_sum = std::make_shared<ngraph::opset1::ReduceSum>(value, reduction_axes, keep_dims);
elems_number = get_num_elements(value, reduction_axes);
elems_number = std::make_shared<ngraph::opset1::Convert>(elems_number, value_elem_type);
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
->add_provenance_group_members_above({value});
}
return std::make_shared<ngraph::opset1::Divide>(value_elems_sum, elems_number)
->add_provenance_group_members_above({value});
}
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
const bool keep_dims = true;
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction) {
const bool keep_dims = true;
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
diff = std::make_shared<ngraph::opset1::ReduceSum>(
std::make_shared<ngraph::opset1::Multiply>(diff, diff),
ngraph::opset1::Constant::create(
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()),
false);
diff = std::make_shared<ngraph::opset1::ReduceSum>(
std::make_shared<ngraph::opset1::Multiply>(diff, diff),
ngraph::opset1::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()),
false);
const auto& et = value.get_element_type();
const auto N = get_num_elements(value.get_shape(), reduction_axes);
const auto& et = value.get_element_type();
const auto N = get_num_elements(value.get_shape(), reduction_axes);
std::shared_ptr<Node> result;
if (bessel_correction)
{
const auto N1const = ngraph::opset1::Constant::create(et, Shape{}, {N - 1});
result = std::make_shared<ngraph::opset1::Divide>(diff, N1const);
}
else
{
const auto Nconst = ngraph::opset1::Constant::create(et, Shape{}, {N});
result = std::make_shared<ngraph::opset1::Divide>(diff, Nconst);
}
return result->add_provenance_group_members_above({value});
}
std::shared_ptr<Node> result;
if (bessel_correction) {
const auto N1const = ngraph::opset1::Constant::create(et, Shape{}, {N - 1});
result = std::make_shared<ngraph::opset1::Divide>(diff, N1const);
} else {
const auto Nconst = ngraph::opset1::Constant::create(et, Shape{}, {N});
result = std::make_shared<ngraph::opset1::Divide>(diff, Nconst);
}
return result->add_provenance_group_members_above({value});
}
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims,
bool bessel_correction)
{
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
std::shared_ptr<Node> builder::opset1::variance(const Output<Node>& value,
const Output<Node>& reduction_axes,
bool keep_dims,
bool bessel_correction) {
std::shared_ptr<Node> mu = opset1::mean(value, reduction_axes, keep_dims);
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
Output<Node> diff = std::make_shared<ngraph::opset1::Subtract>(value, mu);
diff = std::make_shared<ngraph::opset1::ReduceSum>(
std::make_shared<ngraph::opset1::Multiply>(diff, diff), reduction_axes, keep_dims);
diff = std::make_shared<ngraph::opset1::ReduceSum>(std::make_shared<ngraph::opset1::Multiply>(diff, diff),
reduction_axes,
keep_dims);
const auto& et = value.get_element_type();
auto N = get_num_elements(value, reduction_axes);
N = std::make_shared<ngraph::opset1::Convert>(N, et);
const auto& et = value.get_element_type();
auto N = get_num_elements(value, reduction_axes);
N = std::make_shared<ngraph::opset1::Convert>(N, et);
std::shared_ptr<Node> result;
if (bessel_correction)
{
const auto one = std::make_shared<ngraph::opset1::Constant>(et, Shape{}, 1);
N = std::make_shared<ngraph::opset1::Subtract>(N, one);
}
std::shared_ptr<Node> result;
if (bessel_correction) {
const auto one = std::make_shared<ngraph::opset1::Constant>(et, Shape{}, 1);
N = std::make_shared<ngraph::opset1::Subtract>(N, one);
}
result = std::make_shared<ngraph::opset1::Divide>(diff, N);
return result->add_provenance_group_members_above({value});
}
result = std::make_shared<ngraph::opset1::Divide>(diff, N);
return result->add_provenance_group_members_above({value});
}
} // namespace builder
} // namespace ngraph
} // namespace builder
} // namespace ngraph

View File

@ -2,13 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/builder/reshape.hpp"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include "ngraph/axis_vector.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reduce_prod.hpp"
@ -24,141 +25,113 @@
using namespace ngraph;
using namespace std;
shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape)
{
if (value.get_partial_shape().same_scheme(shape))
{
shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape) {
if (value.get_partial_shape().same_scheme(shape)) {
return value.get_node_shared_ptr();
}
else if (is_scalar(shape))
{
} else if (is_scalar(shape)) {
auto value_rank = value.get_shape().size();
AxisVector axes_vector(value_rank);
std::iota(axes_vector.begin(), axes_vector.end(), 0);
auto axes = op::Constant::create(element::i64, Shape{value_rank}, axes_vector);
return std::make_shared<op::Squeeze>(value, axes);
}
else
{
auto out_pattern = op::Constant::create(
element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
} else {
auto out_pattern =
op::Constant::create(element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
return make_shared<ngraph::opset1::Reshape>(value, out_pattern, false)
->add_provenance_group_members_above({value});
}
}
shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
{
const auto axes_order_const =
op::Constant::create(element::i64,
Shape{axes_order.size()},
vector<int64_t>(axes_order.begin(), axes_order.end()));
return make_shared<ngraph::opset1::Transpose>(value, axes_order_const)
->add_provenance_group_members_above({value});
/// \brief Permute the axes of `value` according to `axes_order` via a Transpose.
///
/// \param value       Tensor whose axes are reordered.
/// \param axes_order  Permutation: output axis i takes input axis axes_order[i].
shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order) {
    // Transpose expects the permutation as an i64 constant input.
    const vector<int64_t> order_i64(axes_order.begin(), axes_order.end());
    const auto order_const = op::Constant::create(element::i64, Shape{axes_order.size()}, order_i64);
    return make_shared<ngraph::opset1::Transpose>(value, order_const)->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value)
{
shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value) {
// This part is left to preserve backward compatibility and ensure passing ONNX tests.
if (value.get_partial_shape().is_static())
{
if (value.get_partial_shape().is_static()) {
vector<size_t> axes_order(value.get_shape().size());
iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order));
return builder::opset1::reorder_axes(value, axes_order);
}
const auto input_rank =
std::make_shared<ngraph::opset1::ShapeOf>(std::make_shared<ngraph::opset1::ShapeOf>(value));
const auto input_rank = std::make_shared<ngraph::opset1::ShapeOf>(std::make_shared<ngraph::opset1::ShapeOf>(value));
const auto neg_one = ngraph::opset1::Constant::create(element::i64, Shape{}, {-1});
const auto start_node = std::make_shared<ngraph::opset1::Add>(input_rank, neg_one);
const auto reverse_axes_order =
std::make_shared<ngraph::opset1::Range>(reshape(start_node, Shape{}), // start
neg_one, // stop (exclusive)
neg_one); // step
const auto reverse_axes_order = std::make_shared<ngraph::opset1::Range>(reshape(start_node, Shape{}), // start
neg_one, // stop (exclusive)
neg_one); // step
return std::make_shared<ngraph::opset1::Transpose>(value, reverse_axes_order)
->add_provenance_group_members_above({value});
}
namespace ngraph
{
namespace builder
{
namespace opset1
{
namespace
{
///
/// \brief Return the node representing normalized axis with respect to
/// provided rank.
///
/// \param[in] node_rank The node representing rank used for normalization.
/// \param[in] axis The axis value to be normalized.
///
/// \return The new Constant node representing normalized axis value.
///
std::shared_ptr<Node>
get_normalized_axis_node(const std::shared_ptr<Node> node_rank, int64_t axis)
{
auto axis_node =
ngraph::opset1::Constant::create(element::i64, Shape{1}, {axis});
// shortcut for alredy positive value
if (axis >= 0)
{
return axis_node;
}
namespace ngraph {
namespace builder {
namespace opset1 {
namespace {
///
/// \brief Return the node representing normalized axis with respect to
/// provided rank.
///
/// \param[in] node_rank The node representing rank used for normalization.
/// \param[in] axis The axis value to be normalized.
///
/// \return The new Constant node representing normalized axis value.
///
std::shared_ptr<Node> get_normalized_axis_node(const std::shared_ptr<Node> node_rank, int64_t axis) {
auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{1}, {axis});
// shortcut for alredy positive value
if (axis >= 0) {
return axis_node;
}
// TODO: What if axis value is beyond acceptable values? [-node_rank,
// node_rank-1]
return make_shared<ngraph::opset1::Add>(node_rank, axis_node);
}
} // namespace
} // namespace opset1
} // namespace builder
} // namespace ngraph
// TODO: What if axis value is beyond acceptable values? [-node_rank,
// node_rank-1]
return make_shared<ngraph::opset1::Add>(node_rank, axis_node);
}
} // namespace
} // namespace opset1
} // namespace builder
} // namespace ngraph
shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis)
{
shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis) {
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of
// input tensor. The last dimension is the product of the rest of input tensor dimensions:
// [d_{axis}, ..., d_n]
shared_ptr<Node> output_shape;
if (axis == 0)
{
if (axis == 0) {
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {1, -1});
}
else if (axis == 1)
{
} else if (axis == 1) {
output_shape = ngraph::opset1::Constant::create(element::i64, Shape{2}, {0, -1});
}
else
{
} else {
const auto value_shape = make_shared<ngraph::opset1::ShapeOf>(value);
const auto value_rank = make_shared<ngraph::opset1::ShapeOf>(value_shape);
const auto axis_node = get_normalized_axis_node(value_rank, axis);
const auto first_part_dims = make_shared<ngraph::opset1::StridedSlice>(
value_shape,
ngraph::opset1::Constant::create(element::i64, {1}, {0}),
axis_node,
vector<int64_t>{},
vector<int64_t>{});
const auto first_part_dims_length = make_shared<ngraph::opset1::ReduceProd>(
first_part_dims, ngraph::opset1::Constant::create(element::i64, {}, {0}), true);
const auto first_part_dims =
make_shared<ngraph::opset1::StridedSlice>(value_shape,
ngraph::opset1::Constant::create(element::i64, {1}, {0}),
axis_node,
vector<int64_t>{},
vector<int64_t>{});
const auto first_part_dims_length =
make_shared<ngraph::opset1::ReduceProd>(first_part_dims,
ngraph::opset1::Constant::create(element::i64, {}, {0}),
true);
const auto remaining_part_length =
ngraph::opset1::Constant::create(element::i64, {1}, {-1});
const auto remaining_part_length = ngraph::opset1::Constant::create(element::i64, {1}, {-1});
output_shape = make_shared<ngraph::opset1::Concat>(
OutputVector{first_part_dims_length, remaining_part_length}, 0);
output_shape =
make_shared<ngraph::opset1::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
}
return make_shared<ngraph::opset1::Reshape>(value, output_shape, true)
->add_provenance_group_members_above({value});
return make_shared<ngraph::opset1::Reshape>(value, output_shape, true)->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis)
{
shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis) {
Shape output_shape(value.get_shape());
// Add empty axis at specified position.
auto empty_axis_it = begin(output_shape);
@ -167,40 +140,30 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
return builder::opset1::reshape(value, output_shape);
}
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes)
{
if (axes.empty())
{
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes) {
if (axes.empty()) {
return value.get_node_shared_ptr();
}
Shape in_shape{value.get_shape()};
for (size_t idx = 0; idx < axes.size(); ++idx)
{
for (size_t idx = 0; idx < axes.size(); ++idx) {
in_shape.at(axes.at(idx)) = 0;
}
Shape output_shape;
for (auto axis : in_shape)
{
if (axis != 0)
{
for (auto axis : in_shape) {
if (axis != 0) {
output_shape.push_back(axis);
}
}
return builder::opset1::reshape(value, output_shape);
}
shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value,
const size_t start_axis,
const size_t end_axis)
{
if (start_axis == end_axis)
{
shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value, const size_t start_axis, const size_t end_axis) {
if (start_axis == end_axis) {
return value.get_node_shared_ptr();
}
if (value.get_partial_shape().is_static())
{
if (value.get_partial_shape().is_static()) {
auto shape = value.get_shape();
// Multiply all alements of shape from start_axis to end_axis inclusive
size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
@ -220,22 +183,20 @@ shared_ptr<Node> builder::opset1::collapse(const Output<Node>& value,
// Split lengths used in VariadicSplit
const auto start_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {start_axis});
const auto end_axis_node = ngraph::opset1::Constant::create(element::i64, {1}, {end_axis + 1});
const auto collapsed_axis =
make_shared<ngraph::opset1::Subtract>(end_axis_node, start_axis_node);
const auto collapsed_axis = make_shared<ngraph::opset1::Subtract>(end_axis_node, start_axis_node);
const auto post_axis = make_shared<ngraph::opset1::Subtract>(rank, end_axis_node);
const auto split_lengths = make_shared<ngraph::opset1::Concat>(
OutputVector{start_axis_node, collapsed_axis, post_axis}, 0);
const auto split_lengths =
make_shared<ngraph::opset1::Concat>(OutputVector{start_axis_node, collapsed_axis, post_axis}, 0);
const auto split_axis = ngraph::opset1::Constant::create(element::i64, {}, {0});
const auto split_node =
make_shared<ngraph::opset1::VariadicSplit>(shape, split_axis, split_lengths);
const auto split_node = make_shared<ngraph::opset1::VariadicSplit>(shape, split_axis, split_lengths);
const auto reduced_axis = ngraph::opset1::Constant::create(element::i64, {1}, {0});
const auto collapsed_axis_size =
make_shared<ngraph::opset1::ReduceProd>(split_node->output(1), reduced_axis, true);
const auto collapsed_axis_size = make_shared<ngraph::opset1::ReduceProd>(split_node->output(1), reduced_axis, true);
const auto collapsed_shape = make_shared<ngraph::opset1::Concat>(
OutputVector{split_node->output(0), collapsed_axis_size, split_node->output(2)}, 0);
OutputVector{split_node->output(0), collapsed_axis_size, split_node->output(2)},
0);
return make_shared<ngraph::opset1::Reshape>(value, collapsed_shape, false);
}

View File

@ -3,25 +3,21 @@
//
#include "ngraph/builder/split.hpp"
#include "ngraph/opsets/opset1.hpp"
using namespace ngraph;
OutputVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis)
{
OutputVector builder::opset1::split(const Output<Node>& value, const std::vector<size_t>& split_lengths, int64_t axis) {
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
const auto split_lengths_node =
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split =
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
const auto variadic_split = std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
return variadic_split->outputs();
}
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis) {
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);

View File

@ -12,549 +12,462 @@
#include "ngraph/type.hpp"
///
namespace ngraph
{
class AttributeVisitor;
namespace ngraph {
class AttributeVisitor;
/// \brief Provides access to an attribute of type AT as a value accessor type VAT
template <typename VAT>
class ValueAccessor;
/// \brief Provides access to an attribute of type AT as a value accessor type VAT
template <typename VAT>
class ValueAccessor;
/// \brief ValueAccessor<void> provides an accessor for values that do not have get/set methonds
/// via AttributeVistor.on_adapter.
///
/// All ValueAccessors must be derived from ValueAccessor<void> so that an AttributeVisitor
/// only needs to implement a subset of the on_adapter methods.
template <>
class NGRAPH_API ValueAccessor<void>
{
public:
/// \brief type info enables identification of the value accessor, as well as is_type and
/// as_type.
virtual const DiscreteTypeInfo& get_type_info() const = 0;
virtual ~ValueAccessor() {}
};
/// \brief ValueAccessor<void> provides an accessor for values that do not have get/set methonds
/// via AttributeVistor.on_adapter.
///
/// All ValueAccessors must be derived from ValueAccessor<void> so that an AttributeVisitor
/// only needs to implement a subset of the on_adapter methods.
template <>
class NGRAPH_API ValueAccessor<void> {
public:
/// \brief type info enables identification of the value accessor, as well as is_type and
/// as_type.
virtual const DiscreteTypeInfo& get_type_info() const = 0;
virtual ~ValueAccessor() {}
};
/// \brief Provides access to values via get/set methods from an m_value, typically from
/// ValueReference
///
/// The m_buffer holds a VAT, which may be wider than the attribute AT. For example, serializers
/// that only
/// support int64_t integers would use a ValueAccessor<vector<int64_t>> to reference a
/// vector<int8_t> attribute. Destruction moves the value back to the attribute if it was
/// changed.
/// \tparam VAT The adapter value type; may be wider than the value being accessed.
template <typename VAT>
class ValueAccessor : public ValueAccessor<void>
{
public:
/// Returns the value
virtual const VAT& get() = 0;
/// Sets the value
virtual void set(const VAT& value) = 0;
};
/// \brief Provides access to values via get/set methods from an m_value, typically from
/// ValueReference
///
/// The m_buffer holds a VAT, which may be wider than the attribute AT. For example, serializers
/// that only
/// support int64_t integers would use a ValueAccessor<vector<int64_t>> to reference a
/// vector<int8_t> attribute. Destruction moves the value back to the attribute if it was
/// changed.
/// \tparam VAT The adapter value type; may be wider than the value being accessed.
template <typename VAT>
class ValueAccessor : public ValueAccessor<void> {
public:
/// Returns the value
virtual const VAT& get() = 0;
/// Sets the value
virtual void set(const VAT& value) = 0;
};
template <>
class ValueAccessor<void*> : public ValueAccessor<void>
{
public:
virtual void* get_ptr() = 0;
virtual size_t size() = 0;
};
template <>
class ValueAccessor<void*> : public ValueAccessor<void> {
public:
virtual void* get_ptr() = 0;
virtual size_t size() = 0;
};
template <typename AT>
class DirectValueAccessor : public ValueAccessor<AT>
{
public:
DirectValueAccessor(AT& ref)
: m_ref(ref)
{
}
const AT& get() override { return m_ref; }
void set(const AT& value) override { m_ref = value; }
protected:
AT& m_ref;
};
template <typename AT, typename VAT>
class IndirectScalarValueAccessor : public ValueAccessor<VAT>
{
public:
IndirectScalarValueAccessor(AT& ref)
: m_ref(ref)
, m_buffer()
{
}
const VAT& get() override
{
if (!m_buffer_valid)
{
m_buffer = static_cast<VAT>(m_ref);
m_buffer_valid = true;
}
return m_buffer;
}
void set(const VAT& value) override
{
m_ref = static_cast<AT>(value);
m_buffer_valid = false;
}
protected:
AT& m_ref;
VAT m_buffer;
bool m_buffer_valid{false};
};
template <typename A, typename B>
A copy_from(B& b)
{
A result(b.size());
for (size_t i = 0; i < b.size(); ++i)
{
result[i] =
static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
}
return result;
template <typename AT>
class DirectValueAccessor : public ValueAccessor<AT> {
public:
DirectValueAccessor(AT& ref) : m_ref(ref) {}
const AT& get() override {
return m_ref;
}
void set(const AT& value) override {
m_ref = value;
}
template <typename AT, typename VAT>
class IndirectVectorValueAccessor : public ValueAccessor<VAT>
{
public:
IndirectVectorValueAccessor(AT& ref)
: m_ref(ref)
{
protected:
AT& m_ref;
};
template <typename AT, typename VAT>
class IndirectScalarValueAccessor : public ValueAccessor<VAT> {
public:
IndirectScalarValueAccessor(AT& ref) : m_ref(ref), m_buffer() {}
const VAT& get() override {
if (!m_buffer_valid) {
m_buffer = static_cast<VAT>(m_ref);
m_buffer_valid = true;
}
return m_buffer;
}
const VAT& get() override
{
if (!m_buffer_valid)
{
m_buffer = copy_from<typename std::remove_cv<VAT>::type>(m_ref);
m_buffer_valid = true;
}
return m_buffer;
void set(const VAT& value) override {
m_ref = static_cast<AT>(value);
m_buffer_valid = false;
}
protected:
AT& m_ref;
VAT m_buffer;
bool m_buffer_valid{false};
};
template <typename A, typename B>
A copy_from(B& b) {
A result(b.size());
for (size_t i = 0; i < b.size(); ++i) {
result[i] = static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
}
return result;
}
template <typename AT, typename VAT>
class IndirectVectorValueAccessor : public ValueAccessor<VAT> {
public:
IndirectVectorValueAccessor(AT& ref) : m_ref(ref) {}
const VAT& get() override {
if (!m_buffer_valid) {
m_buffer = copy_from<typename std::remove_cv<VAT>::type>(m_ref);
m_buffer_valid = true;
}
return m_buffer;
}
void set(const VAT& value) override
{
m_ref = copy_from<AT>(value);
m_buffer_valid = false;
}
void set(const VAT& value) override {
m_ref = copy_from<AT>(value);
m_buffer_valid = false;
}
operator AT&() { return m_ref; }
operator AT&() {
return m_ref;
}
protected:
AT& m_ref;
VAT m_buffer;
bool m_buffer_valid{false};
};
protected:
AT& m_ref;
VAT m_buffer;
bool m_buffer_valid{false};
};
/// \brief An AttributeAdapter "captures" an attribute as an AT& and makes it available as a
/// ValueAccessor<VAT>.
template <typename AT>
class AttributeAdapter
{
};
/// \brief An AttributeAdapter "captures" an attribute as an AT& and makes it available as a
/// ValueAccessor<VAT>.
template <typename AT>
class AttributeAdapter {};
/// \brief Access an enum via a string
/// \tparam AT The attribute type enum class
template <typename AT>
class EnumAttributeAdapterBase : public ValueAccessor<std::string>
{
public:
EnumAttributeAdapterBase(AT& value)
: m_ref(value)
{
}
/// \brief Access an enum via a string
/// \tparam AT The attribute type enum class
template <typename AT>
class EnumAttributeAdapterBase : public ValueAccessor<std::string> {
public:
EnumAttributeAdapterBase(AT& value) : m_ref(value) {}
const std::string& get() override { return as_string(m_ref); }
void set(const std::string& value) override { m_ref = as_enum<AT>(value); }
operator AT&() { return m_ref; }
const std::string& get() override {
return as_string(m_ref);
}
void set(const std::string& value) override {
m_ref = as_enum<AT>(value);
}
operator AT&() {
return m_ref;
}
protected:
AT& m_ref;
};
protected:
AT& m_ref;
};
/// Adapters will see visitor
class VisitorAdapter : public ValueAccessor<void>
{
public:
virtual bool visit_attributes(AttributeVisitor& visitor) = 0;
};
/// Adapters will see visitor
class VisitorAdapter : public ValueAccessor<void> {
public:
virtual bool visit_attributes(AttributeVisitor& visitor) = 0;
};
template <>
class NGRAPH_API AttributeAdapter<float> : public IndirectScalarValueAccessor<float, double>
{
public:
AttributeAdapter(float& value)
: IndirectScalarValueAccessor<float, double>(value)
{
}
template <>
class NGRAPH_API AttributeAdapter<float> : public IndirectScalarValueAccessor<float, double> {
public:
AttributeAdapter(float& value) : IndirectScalarValueAccessor<float, double>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<float>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<float>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a double as a double
template <>
class NGRAPH_API AttributeAdapter<double> : public DirectValueAccessor<double>
{
public:
AttributeAdapter(double& value)
: DirectValueAccessor<double>(value)
{
}
/// \brief Access a double as a double
template <>
class NGRAPH_API AttributeAdapter<double> : public DirectValueAccessor<double> {
public:
AttributeAdapter(double& value) : DirectValueAccessor<double>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<double>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<double>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a string as a string
template <>
class NGRAPH_API AttributeAdapter<std::string> : public DirectValueAccessor<std::string>
{
public:
AttributeAdapter(std::string& value)
: DirectValueAccessor<std::string>(value)
{
}
/// \brief Access a string as a string
template <>
class NGRAPH_API AttributeAdapter<std::string> : public DirectValueAccessor<std::string> {
public:
AttributeAdapter(std::string& value) : DirectValueAccessor<std::string>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<string>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<string>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a bool as a bool
template <>
class NGRAPH_API AttributeAdapter<bool> : public DirectValueAccessor<bool>
{
public:
AttributeAdapter(bool& value)
: DirectValueAccessor<bool>(value)
{
}
/// \brief Access a bool as a bool
template <>
class NGRAPH_API AttributeAdapter<bool> : public DirectValueAccessor<bool> {
public:
AttributeAdapter(bool& value) : DirectValueAccessor<bool>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<bool>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<bool>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access an int8_t and an int64_t
template <>
class NGRAPH_API AttributeAdapter<int8_t> : public IndirectScalarValueAccessor<int8_t, int64_t>
{
public:
AttributeAdapter(int8_t& value)
: IndirectScalarValueAccessor<int8_t, int64_t>(value)
{
}
/// \brief Access an int8_t and an int64_t
template <>
class NGRAPH_API AttributeAdapter<int8_t> : public IndirectScalarValueAccessor<int8_t, int64_t> {
public:
AttributeAdapter(int8_t& value) : IndirectScalarValueAccessor<int8_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int8_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int8_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access an int16_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<int16_t>
: public IndirectScalarValueAccessor<int16_t, int64_t>
{
public:
AttributeAdapter(int16_t& value)
: IndirectScalarValueAccessor<int16_t, int64_t>(value)
{
}
/// \brief Access an int16_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<int16_t> : public IndirectScalarValueAccessor<int16_t, int64_t> {
public:
AttributeAdapter(int16_t& value) : IndirectScalarValueAccessor<int16_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int16_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int16_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access an int32_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<int32_t>
: public IndirectScalarValueAccessor<int32_t, int64_t>
{
public:
AttributeAdapter(int32_t& value)
: IndirectScalarValueAccessor<int32_t, int64_t>(value)
{
}
/// \brief Access an int32_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<int32_t> : public IndirectScalarValueAccessor<int32_t, int64_t> {
public:
AttributeAdapter(int32_t& value) : IndirectScalarValueAccessor<int32_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int32_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int32_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access an int64_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<int64_t> : public DirectValueAccessor<int64_t>
{
public:
AttributeAdapter(int64_t& value)
: DirectValueAccessor<int64_t>(value)
{
}
/// \brief Access an int64_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<int64_t> : public DirectValueAccessor<int64_t> {
public:
AttributeAdapter(int64_t& value) : DirectValueAccessor<int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int64_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<int64_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a uint8_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint8_t>
: public IndirectScalarValueAccessor<uint8_t, int64_t>
{
public:
AttributeAdapter(uint8_t& value)
: IndirectScalarValueAccessor<uint8_t, int64_t>(value)
{
}
/// \brief Access a uint8_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint8_t> : public IndirectScalarValueAccessor<uint8_t, int64_t> {
public:
AttributeAdapter(uint8_t& value) : IndirectScalarValueAccessor<uint8_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint8_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint8_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a uint16_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint16_t>
: public IndirectScalarValueAccessor<uint16_t, int64_t>
{
public:
AttributeAdapter(uint16_t& value)
: IndirectScalarValueAccessor<uint16_t, int64_t>(value)
{
}
/// \brief Access a uint16_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint16_t> : public IndirectScalarValueAccessor<uint16_t, int64_t> {
public:
AttributeAdapter(uint16_t& value) : IndirectScalarValueAccessor<uint16_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint16_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint16_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a uint32_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint32_t>
: public IndirectScalarValueAccessor<uint32_t, int64_t>
{
public:
AttributeAdapter(uint32_t& value)
: IndirectScalarValueAccessor<uint32_t, int64_t>(value)
{
}
/// \brief Access a uint32_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint32_t> : public IndirectScalarValueAccessor<uint32_t, int64_t> {
public:
AttributeAdapter(uint32_t& value) : IndirectScalarValueAccessor<uint32_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint32_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint32_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a uint64_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint64_t>
: public IndirectScalarValueAccessor<uint64_t, int64_t>
{
public:
AttributeAdapter(uint64_t& value)
: IndirectScalarValueAccessor<uint64_t, int64_t>(value)
{
}
/// \brief Access a uint64_t as an int64_t
template <>
class NGRAPH_API AttributeAdapter<uint64_t> : public IndirectScalarValueAccessor<uint64_t, int64_t> {
public:
AttributeAdapter(uint64_t& value) : IndirectScalarValueAccessor<uint64_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint64_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<uint64_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
#ifdef __APPLE__
// size_t is one of the uint types on _WIN32
template <>
class NGRAPH_API AttributeAdapter<size_t> : public IndirectScalarValueAccessor<size_t, int64_t>
{
public:
AttributeAdapter(size_t& value)
: IndirectScalarValueAccessor<size_t, int64_t>(value)
{
}
// size_t is one of the uint types on _WIN32
template <>
class NGRAPH_API AttributeAdapter<size_t> : public IndirectScalarValueAccessor<size_t, int64_t> {
public:
AttributeAdapter(size_t& value) : IndirectScalarValueAccessor<size_t, int64_t>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
template <>
class NGRAPH_API AttributeAdapter<std::vector<size_t>>
: public IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>
{
public:
AttributeAdapter(std::vector<size_t>& value)
: IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>(value)
{
}
template <>
class NGRAPH_API AttributeAdapter<std::vector<size_t>>
: public IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>> {
public:
AttributeAdapter(std::vector<size_t>& value)
: IndirectVectorValueAccessor<std::vector<size_t>, std::vector<int64_t>>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<size_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<size_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
#endif
/// Note: These class bodies cannot be defined with templates because of interactions
/// between dllexport and templates on Windows.
/// Note: These class bodies cannot be defined with templates because of interactions
/// between dllexport and templates on Windows.
/// \brief Access a vector<int8_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int8_t>>
: public DirectValueAccessor<std::vector<int8_t>>
{
public:
AttributeAdapter(std::vector<int8_t>& value)
: DirectValueAccessor<std::vector<int8_t>>(value)
{
}
/// \brief Access a vector<int8_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int8_t>> : public DirectValueAccessor<std::vector<int8_t>> {
public:
AttributeAdapter(std::vector<int8_t>& value) : DirectValueAccessor<std::vector<int8_t>>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int8_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int8_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<int16_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int16_t>>
: public DirectValueAccessor<std::vector<int16_t>>
{
public:
AttributeAdapter(std::vector<int16_t>& value)
: DirectValueAccessor<std::vector<int16_t>>(value)
{
}
/// \brief Access a vector<int16_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int16_t>> : public DirectValueAccessor<std::vector<int16_t>> {
public:
AttributeAdapter(std::vector<int16_t>& value) : DirectValueAccessor<std::vector<int16_t>>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int16_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int16_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<int32_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int32_t>>
: public DirectValueAccessor<std::vector<int32_t>>
{
public:
AttributeAdapter(std::vector<int32_t>& value)
: DirectValueAccessor<std::vector<int32_t>>(value)
{
}
/// \brief Access a vector<int32_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int32_t>> : public DirectValueAccessor<std::vector<int32_t>> {
public:
AttributeAdapter(std::vector<int32_t>& value) : DirectValueAccessor<std::vector<int32_t>>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int32_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int32_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<int64_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<int64_t>>
: public DirectValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(std::vector<int64_t>& value)
: DirectValueAccessor<std::vector<int64_t>>(value)
{
}
/// \brief Adapter exposing a std::vector<int64_t> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<int64_t>> : public DirectValueAccessor<std::vector<int64_t>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int64_t>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<int64_t>& value) : DirectValueAccessor<std::vector<int64_t>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<int64_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<uint8_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint8_t>>
: public DirectValueAccessor<std::vector<uint8_t>>
{
public:
AttributeAdapter(std::vector<uint8_t>& value)
: DirectValueAccessor<std::vector<uint8_t>>(value)
{
}
/// \brief Adapter exposing a std::vector<uint8_t> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint8_t>> : public DirectValueAccessor<std::vector<uint8_t>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint8_t>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<uint8_t>& value) : DirectValueAccessor<std::vector<uint8_t>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint8_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<uint16_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint16_t>>
: public DirectValueAccessor<std::vector<uint16_t>>
{
public:
AttributeAdapter(std::vector<uint16_t>& value)
: DirectValueAccessor<std::vector<uint16_t>>(value)
{
}
/// \brief Adapter exposing a std::vector<uint16_t> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint16_t>> : public DirectValueAccessor<std::vector<uint16_t>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint16_t>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<uint16_t>& value) : DirectValueAccessor<std::vector<uint16_t>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint16_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<uint32_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint32_t>>
: public DirectValueAccessor<std::vector<uint32_t>>
{
public:
AttributeAdapter(std::vector<uint32_t>& value)
: DirectValueAccessor<std::vector<uint32_t>>(value)
{
}
/// \brief Adapter exposing a std::vector<uint32_t> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint32_t>> : public DirectValueAccessor<std::vector<uint32_t>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint32_t>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<uint32_t>& value) : DirectValueAccessor<std::vector<uint32_t>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint32_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<uint64_t>
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint64_t>>
: public DirectValueAccessor<std::vector<uint64_t>>
{
public:
AttributeAdapter(std::vector<uint64_t>& value)
: DirectValueAccessor<std::vector<uint64_t>>(value)
{
}
/// \brief Adapter exposing a std::vector<uint64_t> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<uint64_t>> : public DirectValueAccessor<std::vector<uint64_t>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint64_t>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<uint64_t>& value) : DirectValueAccessor<std::vector<uint64_t>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<uint64_t>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<float>
template <>
class NGRAPH_API AttributeAdapter<std::vector<float>>
: public DirectValueAccessor<std::vector<float>>
{
public:
AttributeAdapter(std::vector<float>& value)
: DirectValueAccessor<std::vector<float>>(value)
{
}
/// \brief Adapter exposing a std::vector<float> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<float>> : public DirectValueAccessor<std::vector<float>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<float>& value) : DirectValueAccessor<std::vector<float>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<double>
template <>
class NGRAPH_API AttributeAdapter<std::vector<double>>
: public DirectValueAccessor<std::vector<double>>
{
public:
AttributeAdapter(std::vector<double>& value)
: DirectValueAccessor<std::vector<double>>(value)
{
}
/// \brief Adapter exposing a std::vector<double> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<double>> : public DirectValueAccessor<std::vector<double>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<double>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<double>& value) : DirectValueAccessor<std::vector<double>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<double>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
/// \brief Access a vector<string>
template <>
class NGRAPH_API AttributeAdapter<std::vector<std::string>>
: public DirectValueAccessor<std::vector<std::string>>
{
public:
AttributeAdapter(std::vector<std::string>& value)
: DirectValueAccessor<std::vector<std::string>>(value)
{
}
/// \brief Adapter exposing a std::vector<std::string> attribute to an AttributeVisitor.
template <>
class NGRAPH_API AttributeAdapter<std::vector<std::string>> : public DirectValueAccessor<std::vector<std::string>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};

    /// Wraps \p value for direct read/write access by a visitor.
    AttributeAdapter(std::vector<std::string>& value) : DirectValueAccessor<std::vector<std::string>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
} // namespace ngraph
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
} // namespace ngraph

View File

@ -12,139 +12,125 @@
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
template <typename T>
class ValueAccessor;
class VisitorAdapter;
class Node;
class Function;
namespace ngraph {
template <typename T>
class ValueAccessor;
class VisitorAdapter;
class Node;
class Function;
/// \brief Visits the attributes of a node, primarily for serialization-like tasks.
/// \brief Visits the attributes of a node, primarily for serialization-like tasks.
///
/// Attributes are the node parameters that are always compile-time constants.
/// Values computed from the graph topology and attributes during compilation are not
/// attributes.
///
/// Attributes have a wide variety of types, but serialization formats are more restricted.
/// We asume serialation easily supports scalar types of bool 64-bit signed, string, and double,
/// and has specialized ways to support numeric arrays and raw data+size. The visitor and
/// adapter convert between the limited serialization types and the unlimited attribute types.
///
/// A visitor is passed to an op's visit_attributes method. The visit_attributes method calls
/// the template method visitor.on_attribute<AT>(const std::string& name, AT& value) on each
/// attribute. The visitor can read or write the attribute's value. The on_attribute
/// method creates an AttributeAdapter<AT> for the value and passes it to one of the visitors
/// on_adapter methods. The on_adapter methods expect a reference to a ValueAccessor<VAT> or a
/// VisitorAdapter. A ValueAccessor<VAT> has get/set methods that can be used to read/write the
/// attribute value as type VAT. These methods are triggered by deriving AttributeAdapter<AT>
/// from ValueAccessor<VAT>. For more complex cases, such as structs, the on_adapter method for
/// VisitorAdapter passes the name and visitor to the adapter, so that the adapter can perform
/// additional work such as visiting struct members or sequence values.
///
/// When a node visits an attribute with structure, the node's on_attribute passes a name for
/// the entire attribute, but the struct will have its own methods to be visited. Similarly, a
/// vector will have a sequence of members to be visited. The adapter may use the visitor
/// methods start_struct/finish_struct and start_vector/next_vector/finish_vector to indicate
/// nested members.
///
/// The visitor method get_name_with_context creates a generic nested version of the name.
/// Visitors can override according to their serialization requirements.
///
/// Attributes that are shared_ptr<Node> are special. They must have been already been
/// registered with the visitor using register_node, which needs a shared pointer to a node and
/// a string ID. The ID string will be used to serialize the node or find the node during
/// deserialization.
class NGRAPH_API AttributeVisitor {
public:
virtual ~AttributeVisitor() {}
// Must implement these methods
/// \brief handles all specialized on_adapter methods implemented by the visitor.
///
/// Attributes are the node parameters that are always compile-time constants.
/// Values computed from the graph topology and attributes during compilation are not
/// attributes.
///
/// Attributes have a wide variety of types, but serialization formats are more restricted.
/// We assume serialization easily supports scalar types of bool, 64-bit signed integer, string, and double,
/// and has specialized ways to support numeric arrays and raw data+size. The visitor and
/// adapter convert between the limited serialization types and the unlimited attribute types.
///
/// A visitor is passed to an op's visit_attributes method. The visit_attributes method calls
/// the template method visitor.on_attribute<AT>(const std::string& name, AT& value) on each
/// attribute. The visitor can read or write the attribute's value. The on_attribute
/// method creates an AttributeAdapter<AT> for the value and passes it to one of the visitors
/// on_adapter methods. The on_adapter methods expect a reference to a ValueAccessor<VAT> or a
/// VisitorAdapter. A ValueAccessor<VAT> has get/set methods that can be used to read/write the
/// attribute value as type VAT. These methods are triggered by deriving AttributeAdapter<AT>
/// from ValueAccessor<VAT>. For more complex cases, such as structs, the on_adapter method for
/// VisitorAdapter passes the name and visitor to the adapter, so that the adapter can perform
/// additional work such as visiting struct members or sequence values.
///
/// When a node visits an attribute with structure, the node's on_attribute passes a name for
/// the entire attribute, but the struct will have its own methods to be visited. Similarly, a
/// vector will have a sequence of members to be visited. The adapter may use the visitor
/// methods start_struct/finish_struct and start_vector/next_vector/finish_vector to indicate
/// nested members.
///
/// The visitor method get_name_with_context creates a generic nested version of the name.
/// Visitors can override according to their serialization requirements.
///
/// Attributes that are shared_ptr<Node> are special. They must have been already been
/// registered with the visitor using register_node, which needs a shared pointer to a node and
/// a string ID. The ID string will be used to serialize the node or find the node during
/// deserialization.
class NGRAPH_API AttributeVisitor
{
public:
virtual ~AttributeVisitor() {}
// Must implement these methods
/// \brief handles all specialized on_adapter methods implemented by the visitor.
///
/// The adapter implements get_type_info(), which can be used to determine the adapter
/// directly
/// or via is_type and as_type on any platform
virtual void on_adapter(const std::string& name, ValueAccessor<void>& adapter) = 0;
// The remaining adapter methods fall back on the void adapter if not implemented
virtual void on_adapter(const std::string& name, ValueAccessor<void*>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<bool>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int8_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int16_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int32_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint8_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint16_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint32_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<float>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int8_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int16_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int32_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint8_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint16_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint32_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint64_t>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<float>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<double>>& adapter);
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter);
/// \brief Hook for adapters that need visitor access
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
/// The adapter implements get_type_info(), which can be used to determine the adapter
/// directly
/// or via is_type and as_type on any platform
virtual void on_adapter(const std::string& name, ValueAccessor<void>& adapter) = 0;
// The remaining adapter methods fall back on the void adapter if not implemented
virtual void on_adapter(const std::string& name, ValueAccessor<void*>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<bool>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int8_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int16_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int32_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint8_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint16_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint32_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<float>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int8_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int16_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int32_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<int64_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint8_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint16_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint32_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<uint64_t>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<double>>& adapter);
virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<std::string>>& adapter);
/// \brief Hook for adapters that need visitor access
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
/// \param name attribute name
/// \param adapter reference to a Function ValueAccessor<VAT>
virtual void on_adapter(const std::string& name,
ValueAccessor<std::shared_ptr<Function>>& adapter);
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
/// \param name attribute name
/// \param adapter reference to a Function ValueAccessor<VAT>
virtual void on_adapter(const std::string& name, ValueAccessor<std::shared_ptr<Function>>& adapter);
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
/// to a ValueAccessor<U> for one of the on_adapter methods.
template <typename AT>
void on_attribute(const std::string& name, AT& value)
{
AttributeAdapter<AT> adapter(value);
start_structure(name);
on_adapter(get_name_with_context(), adapter);
finish_structure();
}
/// \returns The nested context of visits
const std::vector<std::string>& get_context() const { return m_context; }
/// \returns context prepended to names
virtual std::string get_name_with_context();
/// \brief Start visiting a nested structure
virtual void start_structure(const std::string& name);
/// \brief Finish visiting a nested structure
virtual std::string finish_structure();
using node_id_t = std::string;
static const node_id_t invalid_node_id;
/// \brief Associate a node with an id.
///
/// No node may be used as an attribute unless it has already been registered with an ID.
/// References to nodes are visited with a ValueAccessor of their ID.
virtual void register_node(const std::shared_ptr<Node>& node,
node_id_t id = invalid_node_id);
/// Returns the node with the given id, or nullptr if there is no registered node
virtual std::shared_ptr<Node> get_registered_node(node_id_t id);
/// Returns the id for the node, or -1 if the node is not registered
virtual node_id_t get_registered_node_id(const std::shared_ptr<Node>& node);
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
/// to a ValueAccessor<U> for one of the on_adapter methods.
template <typename AT>
void on_attribute(const std::string& name, AT& value) {
AttributeAdapter<AT> adapter(value);
start_structure(name);
on_adapter(get_name_with_context(), adapter);
finish_structure();
}
/// \returns The nested context of visits
const std::vector<std::string>& get_context() const {
return m_context;
}
/// \returns context prepended to names
virtual std::string get_name_with_context();
/// \brief Start visiting a nested structure
virtual void start_structure(const std::string& name);
/// \brief Finish visiting a nested structure
virtual std::string finish_structure();
using node_id_t = std::string;
static const node_id_t invalid_node_id;
/// \brief Associate a node with an id.
///
/// No node may be used as an attribute unless it has already been registered with an ID.
/// References to nodes are visited with a ValueAccessor of their ID.
virtual void register_node(const std::shared_ptr<Node>& node, node_id_t id = invalid_node_id);
/// Returns the node with the given id, or nullptr if there is no registered node
virtual std::shared_ptr<Node> get_registered_node(node_id_t id);
/// Returns the id for the node, or -1 if the node is not registered
virtual node_id_t get_registered_node_id(const std::shared_ptr<Node>& node);
protected:
std::vector<std::string> m_context;
std::unordered_map<std::shared_ptr<Node>, node_id_t> m_node_id_map;
std::unordered_map<node_id_t, std::shared_ptr<Node>> m_id_node_map;
};
} // namespace ngraph
protected:
std::vector<std::string> m_context;
std::unordered_map<std::shared_ptr<Node>, node_id_t> m_node_id_map;
std::unordered_map<node_id_t, std::shared_ptr<Node>> m_id_node_map;
};
} // namespace ngraph

View File

@ -12,50 +12,48 @@
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
/// \brief A set of axes.
class AxisSet : public std::set<size_t>
{
public:
NGRAPH_API AxisSet();
namespace ngraph {
/// \brief A set of axes.
class AxisSet : public std::set<size_t> {
public:
NGRAPH_API AxisSet();
NGRAPH_API AxisSet(const std::initializer_list<size_t>& axes);
NGRAPH_API AxisSet(const std::initializer_list<size_t>& axes);
NGRAPH_API AxisSet(const std::set<size_t>& axes);
NGRAPH_API AxisSet(const std::set<size_t>& axes);
NGRAPH_API AxisSet(const std::vector<size_t>& axes);
NGRAPH_API AxisSet(const std::vector<size_t>& axes);
NGRAPH_API AxisSet(const AxisSet& axes);
NGRAPH_API AxisSet(const AxisSet& axes);
NGRAPH_API AxisSet& operator=(const AxisSet& v);
NGRAPH_API AxisSet& operator=(const AxisSet& v);
NGRAPH_API AxisSet& operator=(AxisSet&& v) noexcept;
NGRAPH_API AxisSet& operator=(AxisSet&& v) noexcept;
NGRAPH_API std::vector<int64_t> to_vector() const;
};
NGRAPH_API std::vector<int64_t> to_vector() const;
};
template <>
class NGRAPH_API AttributeAdapter<AxisSet> : public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(AxisSet& value)
: m_ref(value)
{
}
template <>
class NGRAPH_API AttributeAdapter<AxisSet> : public ValueAccessor<std::vector<int64_t>> {
public:
AttributeAdapter(AxisSet& value) : m_ref(value) {}
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
operator AxisSet&() { return m_ref; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
operator AxisSet&() {
return m_ref;
}
protected:
AxisSet& m_ref;
std::vector<int64_t> m_buffer;
bool m_buffer_valid{false};
};
protected:
AxisSet& m_ref;
std::vector<int64_t> m_buffer;
bool m_buffer_valid{false};
};
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
} // namespace ngraph

View File

@ -11,49 +11,41 @@
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
/// \brief A vector of axes.
class AxisVector : public std::vector<size_t>
{
public:
NGRAPH_API AxisVector(const std::initializer_list<size_t>& axes);
namespace ngraph {
/// \brief A vector of axes.
class AxisVector : public std::vector<size_t> {
public:
NGRAPH_API AxisVector(const std::initializer_list<size_t>& axes);
NGRAPH_API AxisVector(const std::vector<size_t>& axes);
NGRAPH_API AxisVector(const std::vector<size_t>& axes);
NGRAPH_API AxisVector(const AxisVector& axes);
NGRAPH_API AxisVector(const AxisVector& axes);
NGRAPH_API explicit AxisVector(size_t n);
NGRAPH_API explicit AxisVector(size_t n);
template <class InputIterator>
AxisVector(InputIterator first, InputIterator last)
: std::vector<size_t>(first, last)
{
}
template <class InputIterator>
AxisVector(InputIterator first, InputIterator last) : std::vector<size_t>(first, last) {}
NGRAPH_API AxisVector();
NGRAPH_API AxisVector();
NGRAPH_API ~AxisVector();
NGRAPH_API ~AxisVector();
NGRAPH_API AxisVector& operator=(const AxisVector& v);
NGRAPH_API AxisVector& operator=(const AxisVector& v);
NGRAPH_API AxisVector& operator=(AxisVector&& v) noexcept;
};
NGRAPH_API AxisVector& operator=(AxisVector&& v) noexcept;
};
template <>
class NGRAPH_API AttributeAdapter<AxisVector>
: public IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>
{
public:
AttributeAdapter(AxisVector& value)
: IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>(value)
{
}
/// \brief Adapter that presents an AxisVector to visitors as a std::vector<int64_t>.
template <>
class NGRAPH_API AttributeAdapter<AxisVector> : public IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>> {
public:
    /// Type descriptor; lets visitors identify this adapter via is_type/as_type.
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisVector>", 0};

    /// Wraps \p value; conversion to/from int64_t is handled by the accessor base.
    AttributeAdapter(AxisVector& value) : IndirectVectorValueAccessor<AxisVector, std::vector<int64_t>>(value) {}

    const DiscreteTypeInfo& get_type_info() const override {
        return type_info;
    }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisVector>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const AxisVector& axis_vector);
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const AxisVector& axis_vector);
} // namespace ngraph

View File

@ -10,39 +10,33 @@
#include "ngraph/except.hpp"
namespace ngraph
{
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; }
template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS&&... args)
{
return write_all_to_stream(str << arg, args...);
}
namespace ngraph {
/// Base case of the variadic writer: no arguments left, just return the stream.
static inline std::ostream& write_all_to_stream(std::ostream& os) { return os; }

/// Streams \p first into \p os, then recurses on the remaining arguments.
/// \return the stream, so the result can be used in further expressions.
template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& os, const T& first, TS&&... rest) {
    os << first;
    return write_all_to_stream(os, rest...);
}
struct CheckLocInfo
{
const char* file;
int line;
const char* check_string;
};
/// \brief Capture site of a failing check; built by the NGRAPH_CHECK_HELPER macros
/// via `::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}`.
struct CheckLocInfo {
    const char* file;          // source file containing the check
    int line;                  // line number of the check
    const char* check_string;  // stringified text of the checked expression
};
/// Base class for check failure exceptions.
class NGRAPH_API CheckFailure : public ngraph_error
{
public:
CheckFailure(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
: ngraph_error(make_what(check_loc_info, context_info, explanation))
{
}
/// Base class for check failure exceptions.
///
/// Thrown by the NGRAPH_CHECK_HELPER macros; the what() text is assembled once by
/// make_what() from the capture-site info, caller context, and explanation.
class NGRAPH_API CheckFailure : public ngraph_error {
public:
    /// \param check_loc_info file/line/check-string of the failing check
    /// \param context_info   caller-supplied context description
    /// \param explanation    streamed explanation text (may be empty)
    CheckFailure(const CheckLocInfo& check_loc_info, const std::string& context_info, const std::string& explanation)
        : ngraph_error(make_what(check_loc_info, context_info, explanation)) {}

private:
    // Formats the three message parts into the single what() string (defined out of line).
    static std::string make_what(const CheckLocInfo& check_loc_info,
                                 const std::string& context_info,
                                 const std::string& explanation);
};
} // namespace ngraph
private:
static std::string make_what(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation);
};
} // namespace ngraph
//
// Helper macro for defining custom check macros, which throw custom exception classes and provide
@ -109,25 +103,20 @@ namespace ngraph
// TODO(amprocte): refactor NGRAPH_CHECK_HELPER so we don't have to introduce a locally-scoped
// variable (ss___) and risk shadowing.
//
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) \
do \
{ \
if (!(check)) \
{ \
::std::stringstream ss___; \
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
throw exc_class( \
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
} \
// Check with an explanatory message: when `check` is false, the trailing arguments are
// streamed into a stringstream via write_all_to_stream and thrown as `exc_class`,
// tagged with the capture site (__FILE__/__LINE__/#check) and `ctx`.
// The do { } while (0) wrapper makes the macro expand to a single statement.
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...)                                              \
    do {                                                                                              \
        if (!(check)) {                                                                               \
            ::std::stringstream ss___;                                                                \
            ::ngraph::write_all_to_stream(ss___, __VA_ARGS__);                                        \
            throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
        }                                                                                             \
    } while (0)
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) \
do \
{ \
if (!(check)) \
{ \
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \
} \
// Message-less variant of NGRAPH_CHECK_HELPER2: throws `exc_class` with the capture-site
// info and `ctx`, and an empty explanation string.
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check)                                           \
    do {                                                                                      \
        if (!(check)) {                                                                       \
            throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \
        }                                                                                     \
    } while (0)
/// \brief Macro to check whether a boolean condition holds.
@ -142,73 +131,46 @@ namespace ngraph
/// implemented with NGRAPH_CHECK macro.
/// \param ... Additional error message that should describe why that execution path is unreachable.
/// \throws ::ngraph::CheckFailure if the macro is executed.
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__)
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) \
CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__)
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
#define GLUE(x, y) x y
#define RETURN_ARG_COUNT(_1_, \
_2_, \
_3_, \
_4_, \
_5_, \
_6, \
_7, \
_8, \
_9, \
_10, \
_11, \
_12, \
_13, \
_14, \
_15, \
_16, \
_17, \
_18, \
_19, \
_20, \
_21, \
_22, \
_23, \
_24, \
_25, \
count, \
...) \
#define RETURN_ARG_COUNT(_1_, \
_2_, \
_3_, \
_4_, \
_5_, \
_6, \
_7, \
_8, \
_9, \
_10, \
_11, \
_12, \
_13, \
_14, \
_15, \
_16, \
_17, \
_18, \
_19, \
_20, \
_21, \
_22, \
_23, \
_24, \
_25, \
count, \
...) \
count
#define EXPAND_ARGS(args) RETURN_ARG_COUNT args
#define COUNT_ARGS_MAXN(...) \
EXPAND_ARGS((__VA_ARGS__, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
2, \
1, \
0))
#define COUNT_ARGS_MAXN(...) \
EXPAND_ARGS((__VA_ARGS__, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0))
#define OVERLOAD_MACRO2(name, count) name##count
#define OVERLOAD_MACRO1(name, count) OVERLOAD_MACRO2(name, count)
#define OVERLOAD_MACRO(name, count) OVERLOAD_MACRO1(name, count)
#define OVERLOAD_MACRO(name, count) OVERLOAD_MACRO1(name, count)
#define CALL_OVERLOAD(name, exc_class, ctx, ...) \
#define CALL_OVERLOAD(name, exc_class, ctx, ...) \
GLUE(OVERLOAD_MACRO(name, COUNT_ARGS_MAXN(__VA_ARGS__)), (exc_class, ctx, __VA_ARGS__))

View File

@ -11,50 +11,42 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
/// \brief Coordinates for a tensor element
class Coordinate : public std::vector<size_t>
{
public:
NGRAPH_API Coordinate();
NGRAPH_API Coordinate(const std::initializer_list<size_t>& axes);
namespace ngraph {
/// \brief Coordinates for a tensor element
class Coordinate : public std::vector<size_t> {
public:
NGRAPH_API Coordinate();
NGRAPH_API Coordinate(const std::initializer_list<size_t>& axes);
NGRAPH_API Coordinate(const Shape& shape);
NGRAPH_API Coordinate(const Shape& shape);
NGRAPH_API Coordinate(const std::vector<size_t>& axes);
NGRAPH_API Coordinate(const std::vector<size_t>& axes);
NGRAPH_API Coordinate(const Coordinate& axes);
NGRAPH_API Coordinate(const Coordinate& axes);
NGRAPH_API Coordinate(size_t n, size_t initial_value = 0);
NGRAPH_API Coordinate(size_t n, size_t initial_value = 0);
NGRAPH_API ~Coordinate();
NGRAPH_API ~Coordinate();
template <class InputIterator>
Coordinate(InputIterator first, InputIterator last)
: std::vector<size_t>(first, last)
{
}
template <class InputIterator>
Coordinate(InputIterator first, InputIterator last) : std::vector<size_t>(first, last) {}
NGRAPH_API Coordinate& operator=(const Coordinate& v);
NGRAPH_API Coordinate& operator=(const Coordinate& v);
NGRAPH_API Coordinate& operator=(Coordinate&& v) noexcept;
};
NGRAPH_API Coordinate& operator=(Coordinate&& v) noexcept;
};
template <>
class NGRAPH_API AttributeAdapter<Coordinate>
: public IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>
{
public:
AttributeAdapter(Coordinate& value)
: IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>(value)
{
}
template <>
class NGRAPH_API AttributeAdapter<Coordinate> : public IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>> {
public:
AttributeAdapter(Coordinate& value) : IndirectVectorValueAccessor<Coordinate, std::vector<int64_t>>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate);
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate);
} // namespace ngraph

View File

@ -11,50 +11,45 @@
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
namespace ngraph {
/// \brief A difference (signed) of tensor element coordinates.
class CoordinateDiff : public std::vector<std::ptrdiff_t> {
public:
NGRAPH_API CoordinateDiff(const std::initializer_list<std::ptrdiff_t>& diffs);
NGRAPH_API CoordinateDiff(const std::vector<std::ptrdiff_t>& diffs);
NGRAPH_API CoordinateDiff(const CoordinateDiff& diffs);
NGRAPH_API explicit CoordinateDiff(size_t n, std::ptrdiff_t initial_value = 0);
template <class InputIterator>
CoordinateDiff(InputIterator first, InputIterator last) : std::vector<std::ptrdiff_t>(first, last) {}
NGRAPH_API ~CoordinateDiff();
NGRAPH_API CoordinateDiff();
NGRAPH_API CoordinateDiff& operator=(const CoordinateDiff& v);
NGRAPH_API CoordinateDiff& operator=(CoordinateDiff&& v) noexcept;
};
template <>
class NGRAPH_API AttributeAdapter<CoordinateDiff>
: public IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>
{
/// \brief A difference (signed) of tensor element coordinates.
class CoordinateDiff : public std::vector<std::ptrdiff_t>
{
public:
NGRAPH_API CoordinateDiff(const std::initializer_list<std::ptrdiff_t>& diffs);
public:
AttributeAdapter(CoordinateDiff& value)
: IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>(value) {}
NGRAPH_API CoordinateDiff(const std::vector<std::ptrdiff_t>& diffs);
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
NGRAPH_API CoordinateDiff(const CoordinateDiff& diffs);
NGRAPH_API explicit CoordinateDiff(size_t n, std::ptrdiff_t initial_value = 0);
template <class InputIterator>
CoordinateDiff(InputIterator first, InputIterator last)
: std::vector<std::ptrdiff_t>(first, last)
{
}
NGRAPH_API ~CoordinateDiff();
NGRAPH_API CoordinateDiff();
NGRAPH_API CoordinateDiff& operator=(const CoordinateDiff& v);
NGRAPH_API CoordinateDiff& operator=(CoordinateDiff&& v) noexcept;
};
template <>
class NGRAPH_API AttributeAdapter<CoordinateDiff>
: public IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>
{
public:
AttributeAdapter(CoordinateDiff& value)
: IndirectVectorValueAccessor<CoordinateDiff, std::vector<int64_t>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff);
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff);
} // namespace ngraph

View File

@ -15,41 +15,41 @@
//
#if defined(_WIN32)
#define NGRAPH_DEPRECATED(msg) __declspec(deprecated(msg))
# define NGRAPH_DEPRECATED(msg) __declspec(deprecated(msg))
#elif defined(__INTEL_COMPILER)
#define NGRAPH_DEPRECATED(msg) __attribute__((deprecated(msg)))
# define NGRAPH_DEPRECATED(msg) __attribute__((deprecated(msg)))
#elif defined(__GNUC__)
#define NGRAPH_DEPRECATED(msg) __attribute__((deprecated((msg))))
# define NGRAPH_DEPRECATED(msg) __attribute__((deprecated((msg))))
#else
#define NGRAPH_DEPRECATED(msg)
# define NGRAPH_DEPRECATED(msg)
#endif
// Suppress warning "-Wdeprecated-declarations" / C4996
#if defined(_MSC_VER)
#define NGRAPH_DO_PRAGMA(x) __pragma(x)
# define NGRAPH_DO_PRAGMA(x) __pragma(x)
#elif defined(__GNUC__)
#define NGRAPH_DO_PRAGMA(x) _Pragma(#x)
# define NGRAPH_DO_PRAGMA(x) _Pragma(# x)
#else
#define NGRAPH_DO_PRAGMA(x)
# define NGRAPH_DO_PRAGMA(x)
#endif
#if defined(_MSC_VER) && !defined(__clang__)
#define NGRAPH_SUPPRESS_DEPRECATED_START \
NGRAPH_DO_PRAGMA(warning(push)) \
NGRAPH_DO_PRAGMA(warning(disable : 4996))
#define NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_DO_PRAGMA(warning(pop))
# define NGRAPH_SUPPRESS_DEPRECATED_START \
NGRAPH_DO_PRAGMA(warning(push)) \
NGRAPH_DO_PRAGMA(warning(disable : 4996))
# define NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_DO_PRAGMA(warning(pop))
#elif defined(__INTEL_COMPILER)
#define NGRAPH_SUPPRESS_DEPRECATED_START \
NGRAPH_DO_PRAGMA(warning(push)) \
NGRAPH_DO_PRAGMA(warning(disable : 1478))
# define NGRAPH_SUPPRESS_DEPRECATED_START \
NGRAPH_DO_PRAGMA(warning(push)) \
NGRAPH_DO_PRAGMA(warning(disable : 1478))
NGRAPH_DO_PRAGMA(warning(disable : 1786))
#define NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_DO_PRAGMA(warning(pop))
# define NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_DO_PRAGMA(warning(pop))
#elif defined(__clang__) || ((__GNUC__) && (__GNUC__ * 100 + __GNUC_MINOR__ > 405))
#define NGRAPH_SUPPRESS_DEPRECATED_START \
NGRAPH_DO_PRAGMA(GCC diagnostic push) \
NGRAPH_DO_PRAGMA(GCC diagnostic ignored "-Wdeprecated-declarations")
#define NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_DO_PRAGMA(GCC diagnostic pop)
# define NGRAPH_SUPPRESS_DEPRECATED_START \
NGRAPH_DO_PRAGMA(GCC diagnostic push) \
NGRAPH_DO_PRAGMA(GCC diagnostic ignored "-Wdeprecated-declarations")
# define NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_DO_PRAGMA(GCC diagnostic pop)
#else
#define NGRAPH_SUPPRESS_DEPRECATED_START
#define NGRAPH_SUPPRESS_DEPRECATED_END
# define NGRAPH_SUPPRESS_DEPRECATED_START
# define NGRAPH_SUPPRESS_DEPRECATED_END
#endif

View File

@ -10,104 +10,119 @@
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/variant.hpp"
namespace ngraph
{
class Node;
namespace ngraph {
class Node;
namespace descriptor
{
class Output;
namespace descriptor {
class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class NGRAPH_API Input
{
friend class ngraph::Node;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class NGRAPH_API Input {
friend class ngraph::Node;
public:
/// \param node The node that owns this input
/// \param index The position of this this tensor in all input tensors
/// \param output The output that supplies a value for this input
Input(Node* node, size_t index, Output& output);
/// \brief Create an Input that is not connected to an output
/// \param node The node that owns this input
/// \param index The position of this this tensor in all input tensors
Input(Node* node, size_t index);
~Input();
public:
/// \param node The node that owns this input
/// \param index The position of this this tensor in all input tensors
/// \param output The output that supplies a value for this input
Input(Node* node, size_t index, Output& output);
/// \brief Create an Input that is not connected to an output
/// \param node The node that owns this input
/// \param index The position of this this tensor in all input tensors
Input(Node* node, size_t index);
~Input();
/// \return the node that this is an input of
std::shared_ptr<Node> get_node() const;
/// \return the node that this is an input of
std::shared_ptr<Node> get_node() const;
/// \return the raw pointer to the node that this is an input of
Node* get_raw_pointer_node() const { return m_node; }
/// \return the position within all supplied tensors of this input
size_t get_index() const { return m_index; }
/// \return the connected output
const Output& get_output() const { return *m_output; }
/// \return the connected output
Output& get_output() { return *m_output; }
/// \return true if an output is connected to the input.
bool has_output() const { return m_output != nullptr; }
/// \return the tensor of the connected output
const Tensor& get_tensor() const;
/// \return the raw pointer to the node that this is an input of
Node* get_raw_pointer_node() const {
return m_node;
}
/// \return the position within all supplied tensors of this input
size_t get_index() const {
return m_index;
}
/// \return the connected output
const Output& get_output() const {
return *m_output;
}
/// \return the connected output
Output& get_output() {
return *m_output;
}
/// \return true if an output is connected to the input.
bool has_output() const {
return m_output != nullptr;
}
/// \return the tensor of the connected output
const Tensor& get_tensor() const;
/// \return the tensor of the connected output
Tensor& get_tensor();
/// \return the tensor of the connected output
Tensor& get_tensor();
RTMap& get_rt_info() { return m_rt_info; }
const RTMap& get_rt_info() const { return m_rt_info; }
RTMap& get_rt_info() {
return m_rt_info;
}
const RTMap& get_rt_info() const {
return m_rt_info;
}
/// \brief Replace the current output that supplies a value for this input with output i
/// of node
void replace_output(std::shared_ptr<Node> node, size_t i);
/// \brief Replace the current output that supplies a value for this input with output
void replace_output(Output& output);
/// \brief Remove the output from this input. The node will not be valid until another
/// output is supplied.
void remove_output();
/// \brief Replace the current output that supplies a value for this input with output i
/// of node
void replace_output(std::shared_ptr<Node> node, size_t i);
/// \brief Replace the current output that supplies a value for this input with output
void replace_output(Output& output);
/// \brief Remove the output from this input. The node will not be valid until another
/// output is supplied.
void remove_output();
/// \return true if the value of this input is relevant to the output shapes of the
/// corresponding node. (Usually this is false.)
///
/// See Node::set_input_is_relevant_to_shape for more details.
bool get_is_relevant_to_shape() const { return m_is_relevant_to_shape; }
/// \return true if the value of this input is relevant to the output value of the
/// corresponding node. (Usually this is true.)
///
/// See Node::set_input_is_relevant_to_value for more details.
bool get_is_relevant_to_value() const { return m_is_relevant_to_value; }
/// \return true if the value of this input is relevant to the output shapes of the
/// corresponding node. (Usually this is false.)
///
/// See Node::set_input_is_relevant_to_shape for more details.
bool get_is_relevant_to_shape() const {
return m_is_relevant_to_shape;
}
/// \return true if the value of this input is relevant to the output value of the
/// corresponding node. (Usually this is true.)
///
/// See Node::set_input_is_relevant_to_value for more details.
bool get_is_relevant_to_value() const {
return m_is_relevant_to_value;
}
protected:
/// \return the tensor for the connected output
std::shared_ptr<const Tensor> get_tensor_ptr() const;
protected:
/// \return the tensor for the connected output
std::shared_ptr<const Tensor> get_tensor_ptr() const;
/// \return the tensor for the connected output
std::shared_ptr<Tensor> get_tensor_ptr();
/// \return the tensor for the connected output
std::shared_ptr<Tensor> get_tensor_ptr();
public:
/// \return the shape of the connected output
const Shape& get_shape() const;
public:
/// \return the shape of the connected output
const Shape& get_shape() const;
/// \return the partial shape of the connected output
const PartialShape& get_partial_shape() const;
/// \return the partial shape of the connected output
const PartialShape& get_partial_shape() const;
/// \return the element type of the connected output
const element::Type& get_element_type() const;
/// \return the element type of the connected output
const element::Type& get_element_type() const;
Input(const Input&) = default;
Input(Input&&) = default;
Input& operator=(const Input&) = default;
Input(const Input&) = default;
Input(Input&&) = default;
Input& operator=(const Input&) = default;
protected:
// owner of an argument node (in lieu of m_arguments)
std::shared_ptr<Node> m_src_node;
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
Output* m_output;
RTMap m_rt_info;
protected:
// owner of an argument node (in lieu of m_arguments)
std::shared_ptr<Node> m_src_node;
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
Output* m_output;
RTMap m_rt_info;
private:
bool m_is_relevant_to_shape;
bool m_is_relevant_to_value;
};
} // namespace descriptor
} // namespace ngraph
private:
bool m_is_relevant_to_shape;
bool m_is_relevant_to_value;
};
} // namespace descriptor
} // namespace ngraph

View File

@ -15,65 +15,68 @@
#include "ngraph/node_output.hpp"
#include "ngraph/variant.hpp"
namespace ngraph
{
// The forward declaration of Node is needed here because Node has a deque of
// Outputs, and Output is an incomplete type at this point. STL containers of
// incomplete type have undefined behavior according to the C++11 standard, and
// in practice including node.hpp here was causing compilation errors on some
// systems (namely macOS).
class Node;
namespace ngraph {
// The forward declaration of Node is needed here because Node has a deque of
// Outputs, and Output is an incomplete type at this point. STL containers of
// incomplete type have undefined behavior according to the C++11 standard, and
// in practice including node.hpp here was causing compilation errors on some
// systems (namely macOS).
class Node;
namespace descriptor
{
// Describes an output tensor of an op
class NGRAPH_API Output
{
public:
Output()
: m_node(nullptr)
, m_index(0)
, m_tensor(nullptr)
, m_inputs()
{
}
namespace descriptor {
// Describes an output tensor of an op
class NGRAPH_API Output {
public:
Output() : m_node(nullptr), m_index(0), m_tensor(nullptr), m_inputs() {}
/// \param node Node that owns this output.
/// \param index Position of the output tensor in all output tensors
/// \param tensor The tensor where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
/// \param node Node that owns this output.
/// \param index Position of the output tensor in all output tensors
/// \param tensor The tensor where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; }
ngraph::Output<Node> get_output() const;
std::shared_ptr<Tensor> get_tensor_ptr() const { return m_tensor; }
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
void add_input(Input* input);
void remove_input(Input* input);
const std::vector<Input*>& get_inputs() const { return m_inputs; }
Tensor& get_tensor() const;
std::shared_ptr<Node> get_node() const;
size_t get_index() const {
return m_index;
}
ngraph::Output<Node> get_output() const;
std::shared_ptr<Tensor> get_tensor_ptr() const {
return m_tensor;
}
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) {
m_tensor = tensor;
}
void add_input(Input* input);
void remove_input(Input* input);
const std::vector<Input*>& get_inputs() const {
return m_inputs;
}
Tensor& get_tensor() const;
RTMap& get_rt_info() { return m_rt_info; }
const RTMap& get_rt_info() const { return m_rt_info; }
/// \return the shape of the output
const Shape& get_shape() const;
RTMap& get_rt_info() {
return m_rt_info;
}
const RTMap& get_rt_info() const {
return m_rt_info;
}
/// \return the shape of the output
const Shape& get_shape() const;
/// \return the partial shape of the output
const PartialShape& get_partial_shape() const;
/// \return the partial shape of the output
const PartialShape& get_partial_shape() const;
/// \return the element type of the output
const element::Type& get_element_type() const;
/// \return the element type of the output
const element::Type& get_element_type() const;
Output(const Output&) = default;
Output(Output&&) = default;
Output& operator=(const Output&) = default;
Output(const Output&) = default;
Output(Output&&) = default;
Output& operator=(const Output&) = default;
protected:
Node* m_node;
size_t m_index;
std::shared_ptr<Tensor> m_tensor;
RTMap m_rt_info;
std::vector<Input*> m_inputs;
};
} // namespace descriptor
} // namespace ngraph
protected:
Node* m_node;
size_t m_index;
std::shared_ptr<Tensor> m_tensor;
RTMap m_rt_info;
std::vector<Input*> m_inputs;
};
} // namespace descriptor
} // namespace ngraph

View File

@ -14,91 +14,89 @@
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
class Node;
namespace ngraph {
class Node;
namespace runtime
{
class HostTensor;
namespace runtime {
class HostTensor;
}
using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
namespace descriptor {
/// \brief Compile-time descriptor of a first-class value that is a tensor.
class NGRAPH_API Tensor {
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
public:
Tensor(const element::Type& element_type, const PartialShape& pshape, const std::string& name);
Tensor(const element::Type& element_type, const PartialShape& pshape, Node* node, size_t node_output_number);
NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
const std::string& get_name() const;
NGRAPH_DEPRECATED("set_name() is deprecated! Please use set_names() instead.")
void set_name(const std::string& name);
const std::unordered_set<std::string>& get_names() const;
void set_names(const std::unordered_set<std::string>& names);
void add_names(const std::unordered_set<std::string>& names);
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
void set_element_type(const element::Type& elemenet_type);
void set_partial_shape(const PartialShape& partial_shape);
/// \brief sets lower bound value description
void set_lower_value(const HostTensorPtr& value);
/// \brief sets upper bound value description
void set_upper_value(const HostTensorPtr& value);
/// \brief unsets bound value descriptions
void invalidate_values();
const element::Type& get_element_type() const {
return m_element_type;
}
using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
namespace descriptor
{
/// \brief Compile-time descriptor of a first-class value that is a tensor.
class NGRAPH_API Tensor
{
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
const Shape& get_shape() const;
const PartialShape& get_partial_shape() const {
return m_partial_shape;
}
/// \brief gets lower bound value description
HostTensorPtr get_lower_value() const {
return m_lower_value;
}
/// \brief gets upper bound value description
HostTensorPtr get_upper_value() const {
return m_upper_value;
}
/// \brief checks if lower and upper bound are set and point to the same HostTensor
bool has_and_set_bound() const {
return m_upper_value != nullptr && m_upper_value == m_lower_value;
}
size_t size() const;
public:
Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name);
Tensor(const element::Type& element_type,
const PartialShape& pshape,
Node* node,
size_t node_output_number);
protected:
element::Type m_element_type;
NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
const std::string& get_name() const;
NGRAPH_DEPRECATED("set_name() is deprecated! Please use set_names() instead.")
void set_name(const std::string& name);
// TODO: remove along with get_shape
// Initially there was ngraph::Shape m_shape only available to keep shape information.
// Support for dynamic shapes required transition to ngraph::PartialShape.
// To smoothly transition to ngraph::PartialShape we introduced m_partial_shape
// and kept m_shape in sync with m_partial_shape. Synchronization point was placed
// in set_partial_shape which dramatically affected performance of ngraph::Function
// validation. Since we have started the transition to ngraph::PartialShape and reduced
// ngraph::Shape usage the only user of m_shape was get_shape method with signature:
// const Shape& descriptor::Tensor::get_shape() const
// It was decided to move m_shape and m_partial_shape synchronization point there and
// to keep methods signature backward compatible.
mutable std::mutex shape_mutex;
mutable std::atomic_bool m_shape_changed;
mutable Shape m_shape;
// TODO: end
const std::unordered_set<std::string>& get_names() const;
void set_names(const std::unordered_set<std::string>& names);
void add_names(const std::unordered_set<std::string>& names);
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
void set_element_type(const element::Type& elemenet_type);
void set_partial_shape(const PartialShape& partial_shape);
PartialShape m_partial_shape;
HostTensorPtr m_lower_value, m_upper_value;
std::string m_name;
std::unordered_set<std::string> m_names;
};
/// \brief sets lower bound value description
void set_lower_value(const HostTensorPtr& value);
/// \brief sets upper bound value description
void set_upper_value(const HostTensorPtr& value);
/// \brief unsets bound value descriptions
void invalidate_values();
const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const;
const PartialShape& get_partial_shape() const { return m_partial_shape; }
/// \brief gets lower bound value description
HostTensorPtr get_lower_value() const { return m_lower_value; }
/// \brief gets upper bound value description
HostTensorPtr get_upper_value() const { return m_upper_value; }
/// \brief checks if lower and upper bound are set and point to the same HostTensor
bool has_and_set_bound() const
{
return m_upper_value != nullptr && m_upper_value == m_lower_value;
}
size_t size() const;
protected:
element::Type m_element_type;
// TODO: remove along with get_shape
// Initially there was ngraph::Shape m_shape only available to keep shape information.
// Support for dynamic shapes required transition to ngraph::PartialShape.
// To smoothly transition to ngraph::PartialShape we introduced m_partial_shape
// and kept m_shape in sync with m_partial_shape. Synchronization point was placed
// in set_partial_shape which dramatically affected performance of ngraph::Function
// validation. Since we have started the transition to ngraph::PartialShape and reduced
// ngraph::Shape usage the only user of m_shape was get_shape method with signature:
// const Shape& descriptor::Tensor::get_shape() const
// It was decided to move m_shape and m_partial_shape synchronization point there and
// to keep methods signature backward compatible.
mutable std::mutex shape_mutex;
mutable std::atomic_bool m_shape_changed;
mutable Shape m_shape;
// TODO: end
PartialShape m_partial_shape;
HostTensorPtr m_lower_value, m_upper_value;
std::string m_name;
std::unordered_set<std::string> m_names;
};
NGRAPH_API
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
} // namespace descriptor
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
} // namespace descriptor
} // namespace ngraph

View File

@ -4,163 +4,171 @@
#pragma once
#include <limits>
#include <stddef.h>
#include <limits>
#include <stdexcept>
#include "ngraph/deprecated.hpp"
#include "ngraph/interval.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
namespace ngraph {
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
///
/// Static dimensions may be implicitly converted from value_type. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic().
class NGRAPH_API Dimension {
public:
using value_type = int64_t;
/// \brief Construct a static dimension.
/// \param dimension Value of the dimension.
Dimension(value_type dimension);
/// \brief Construct a dynamic dimension with bounded range
/// \param min_dimension The lower inclusive limit for the dimension
/// \param mas_dimension The upper inclusive limit for the dimension
Dimension(value_type min_dimension, value_type max_dimension);
/// \brief Construct a dynamic dimension with range [0, ...]
Dimension() = default;
bool operator==(const Dimension& dimension) const {
return m_dimension == dimension.m_dimension;
}
bool operator!=(const Dimension& dimension) const {
return m_dimension != dimension.m_dimension;
}
/// \brief Check whether this dimension is static.
/// \return `true` if the dimension is static, else `false`.
bool is_static() const {
return m_dimension.size() == 1;
}
/// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const {
return m_dimension.size() != 1;
}
/// \brief Convert this dimension to `value_type`. This dimension must be static and
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
value_type get_length() const;
value_type get_min_length() const;
value_type get_max_length() const;
/// \brief Return the interval of valid lengths
const Interval& get_interval() const {
return m_dimension;
}
Interval& get_interval() {
return m_dimension;
}
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
/// \param dim The other dimension to compare this dimension to.
/// \return `true` if this dimension and `dim` are both dynamic, or if they are both
/// static and equal; otherwise, `false`.
bool same_scheme(const Dimension& dim) const;
/// \brief Try to merge two Dimension objects together.
/// \param[out] dst Reference to write the merged Dimension into.
/// \param d1 First dimension to merge.
/// \param d2 Second dimension to merge.
/// \return `true` if merging succeeds, else `false`.
///
/// Static dimensions may be implicitly converted from value_type. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic().
class NGRAPH_API Dimension
{
public:
using value_type = int64_t;
/// \li If `d1` is dynamic, writes `d2` to `dst` and returns `true`.
/// \li If `d2` is dynamic, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are static and equal, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are both static and unequal, leaves `dst` unchanged and
/// returns `false`.
static bool merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Construct a static dimension.
/// \param dimension Value of the dimension.
Dimension(value_type dimension);
/// \brief Try to merge two Dimension objects together with implicit broadcasting
/// of unit-sized dimension to non unit-sized dimension
static bool broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Construct a dynamic dimension with bounded range
/// \param min_dimension The lower inclusive limit for the dimension
/// \param mas_dimension The upper inclusive limit for the dimension
Dimension(value_type min_dimension, value_type max_dimension);
/// \brief Construct a dynamic dimension with range [0, ...]
Dimension() = default;
bool operator==(const Dimension& dimension) const
{
return m_dimension == dimension.m_dimension;
}
bool operator!=(const Dimension& dimension) const
{
return m_dimension != dimension.m_dimension;
}
/// \brief Check whether this dimension is static.
/// \return `true` if the dimension is static, else `false`.
bool is_static() const { return m_dimension.size() == 1; }
/// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return m_dimension.size() != 1; }
/// \brief Convert this dimension to `value_type`. This dimension must be static and
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
value_type get_length() const;
value_type get_min_length() const;
value_type get_max_length() const;
/// \brief Return the interval of valid lengths
const Interval& get_interval() const { return m_dimension; }
Interval& get_interval() { return m_dimension; }
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
/// \param dim The other dimension to compare this dimension to.
/// \return `true` if this dimension and `dim` are both dynamic, or if they are both
/// static and equal; otherwise, `false`.
bool same_scheme(const Dimension& dim) const;
/// \brief Try to merge two Dimension objects together.
/// \param[out] dst Reference to write the merged Dimension into.
/// \param d1 First dimension to merge.
/// \param d2 Second dimension to merge.
/// \return `true` if merging succeeds, else `false`.
///
/// \li If `d1` is dynamic, writes `d2` to `dst` and returns `true`.
/// \li If `d2` is dynamic, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are static and equal, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are both static and unequal, leaves `dst` unchanged and
/// returns `false`.
static bool merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Try to merge two Dimension objects together with implicit broadcasting
/// of unit-sized dimension to non unit-sized dimension
static bool broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Check whether this dimension is capable of being merged with the argument
/// dimension.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension is compatible with `d`, else `false`.
///
/// Two dimensions are considered compatible if it is possible to merge them. (See
/// Dimension::merge.)
bool compatible(const Dimension& d) const;
/// \brief Check whether this dimension is a relaxation of the argument.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension relaxes `d`, else `false`.
///
/// A dimension `d1` _relaxes_ (or _is a relaxation of_) `d2` if `d1` and `d2` are static
/// and equal, or `d1` is dynamic.
///
/// `d1.relaxes(d2)` is equivalent to `d2.refines(d1)`.
bool relaxes(const Dimension& d) const;
/// \brief Check whether this dimension is a refinement of the argument.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension relaxes `d`, else `false`.
///
/// A dimension `d2` _refines_ (or _is a refinement of_) `d1` if `d1` and `d2` are static
/// and equal, or `d2` is dynamic.
///
/// `d1.refines(d2)` is equivalent to `d2.relaxes(d1)`.
bool refines(const Dimension& d) const;
/// \brief Create a dynamic dimension.
/// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); }
/// \brief Addition operator for Dimension.
/// \param dim Right operand for addition.
/// \return Smallest interval dimension enclosing inputs
Dimension operator+(const Dimension& dim) const;
/// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction.
/// \return Smallest interval dimension enclosing inputs
Dimension operator-(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplicaiton.
/// \return Smallest interval containing all "produces" which are 0 if either of `this` or
/// `dim` has length `0`, else unbounded if either is unbounded, else product of lengths.
Dimension operator*(const Dimension& dim) const;
/// \brief Add-into operator for Dimension.
/// \param dim Right operand for addition.
/// \return A reference to `*this`, after updating `*this` to the value `*this + dim`.
Dimension& operator+=(const Dimension& dim) { return (*this = *this + dim); }
/// \brief Multiply-into operator for Dimension.
/// \param dim Right operand for multiplication.
/// \return A reference to `*this`, after updating `*this` to the value `*this * dim`.
Dimension& operator*=(const Dimension& dim) { return (*this = *this * dim); }
/// \brief Intersection of dimensions
Dimension operator&(const Dimension& dim) const;
/// \brief Intersection of dimensions
Dimension& operator&=(const Dimension& dim);
private:
Dimension(const Interval& interval)
: m_dimension(interval)
{
}
// The actual numerical value of the dimension.
Interval m_dimension{};
};
/// \brief Insert a human-readable representation of a dimension into an output stream.
/// \param str The output stream targeted for insertion.
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
/// \brief Check whether this dimension is capable of being merged with the argument
/// dimension.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension is compatible with `d`, else `false`.
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`.
NGRAPH_API
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
} // namespace ngraph
/// Two dimensions are considered compatible if it is possible to merge them. (See
/// Dimension::merge.)
bool compatible(const Dimension& d) const;
/// \brief Check whether this dimension is a relaxation of the argument.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension relaxes `d`, else `false`.
///
/// A dimension `d1` _relaxes_ (or _is a relaxation of_) `d2` if `d1` and `d2` are static
/// and equal, or `d1` is dynamic.
///
/// `d1.relaxes(d2)` is equivalent to `d2.refines(d1)`.
bool relaxes(const Dimension& d) const;
/// \brief Check whether this dimension is a refinement of the argument.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension relaxes `d`, else `false`.
///
/// A dimension `d2` _refines_ (or _is a refinement of_) `d1` if `d1` and `d2` are static
/// and equal, or `d2` is dynamic.
///
/// `d1.refines(d2)` is equivalent to `d2.relaxes(d1)`.
bool refines(const Dimension& d) const;
/// \brief Create a dynamic dimension.
/// \return A dynamic dimension.
static Dimension dynamic() {
return Dimension();
}
/// \brief Addition operator for Dimension.
/// \param dim Right operand for addition.
/// \return Smallest interval dimension enclosing inputs
Dimension operator+(const Dimension& dim) const;
/// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction.
/// \return Smallest interval dimension enclosing inputs
Dimension operator-(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplicaiton.
/// \return Smallest interval containing all "produces" which are 0 if either of `this` or
/// `dim` has length `0`, else unbounded if either is unbounded, else product of lengths.
Dimension operator*(const Dimension& dim) const;
/// \brief Add-into operator for Dimension.
/// \param dim Right operand for addition.
/// \return A reference to `*this`, after updating `*this` to the value `*this + dim`.
Dimension& operator+=(const Dimension& dim) {
return (*this = *this + dim);
}
/// \brief Multiply-into operator for Dimension.
/// \param dim Right operand for multiplication.
/// \return A reference to `*this`, after updating `*this` to the value `*this * dim`.
Dimension& operator*=(const Dimension& dim) {
return (*this = *this * dim);
}
/// \brief Intersection of dimensions
Dimension operator&(const Dimension& dim) const;
/// \brief Intersection of dimensions
Dimension& operator&=(const Dimension& dim);
private:
Dimension(const Interval& interval) : m_dimension(interval) {}
// The actual numerical value of the dimension.
Interval m_dimension{};
};
/// \brief Insert a human-readable representation of a dimension into an output stream.
/// \param str The output stream targeted for insertion.
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`.
NGRAPH_API
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
} // namespace ngraph

View File

@ -12,33 +12,27 @@
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace reduction
{
enum class Type
{
SUM,
PROD,
MIN,
MAX,
};
namespace ngraph {
namespace reduction {
enum class Type {
SUM,
PROD,
MIN,
MAX,
};
NGRAPH_API
std::ostream& operator<<(std::ostream& out, const Type& obj);
} // namespace reduction
NGRAPH_API
std::ostream& operator<<(std::ostream& out, const Type& obj);
} // namespace reduction
template <>
class NGRAPH_API AttributeAdapter<reduction::Type>
: public EnumAttributeAdapterBase<reduction::Type>
{
public:
AttributeAdapter(reduction::Type& value)
: EnumAttributeAdapterBase<reduction::Type>(value)
{
}
template <>
class NGRAPH_API AttributeAdapter<reduction::Type> : public EnumAttributeAdapterBase<reduction::Type> {
public:
AttributeAdapter(reduction::Type& value) : EnumAttributeAdapterBase<reduction::Type>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
} // namespace ngraph

View File

@ -10,75 +10,61 @@
#include "ngraph/check.hpp"
namespace ngraph
{
/// Uses a pairings defined by EnumTypes::get() to convert between strings
/// and enum values.
template <typename EnumType>
class EnumNames
{
public:
/// Converts strings to enum values
static EnumType as_enum(const std::string& name)
{
auto to_lower = [](const std::string& s) {
std::string rc = s;
std::transform(rc.begin(), rc.end(), rc.begin(), [](char c) {
return static_cast<char>(::tolower(static_cast<int>(c)));
});
return rc;
};
for (const auto& p : get().m_string_enums)
{
if (to_lower(p.first) == to_lower(name))
{
return p.second;
}
namespace ngraph {
/// Uses a pairings defined by EnumTypes::get() to convert between strings
/// and enum values.
template <typename EnumType>
class EnumNames {
public:
/// Converts strings to enum values
static EnumType as_enum(const std::string& name) {
auto to_lower = [](const std::string& s) {
std::string rc = s;
std::transform(rc.begin(), rc.end(), rc.begin(), [](char c) {
return static_cast<char>(::tolower(static_cast<int>(c)));
});
return rc;
};
for (const auto& p : get().m_string_enums) {
if (to_lower(p.first) == to_lower(name)) {
return p.second;
}
NGRAPH_CHECK(false, "\"", name, "\"", " is not a member of enum ", get().m_enum_name);
}
/// Converts enum values to strings
static const std::string& as_string(EnumType e)
{
for (const auto& p : get().m_string_enums)
{
if (p.second == e)
{
return p.first;
}
}
NGRAPH_CHECK(false, " invalid member of enum ", get().m_enum_name);
}
private:
/// Creates the mapping.
EnumNames(const std::string& enum_name,
const std::vector<std::pair<std::string, EnumType>> string_enums)
: m_enum_name(enum_name)
, m_string_enums(string_enums)
{
}
/// Must be defined to returns a singleton for each supported enum class
static EnumNames<EnumType>& get();
const std::string m_enum_name;
std::vector<std::pair<std::string, EnumType>> m_string_enums;
};
/// Returns the enum value matching the string
template <typename Type, typename Value>
typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type
as_enum(const Value& value)
{
return EnumNames<Type>::as_enum(value);
NGRAPH_CHECK(false, "\"", name, "\"", " is not a member of enum ", get().m_enum_name);
}
/// Returns the string matching the enum value
template <typename Value>
const std::string& as_string(Value value)
{
return EnumNames<Value>::as_string(value);
/// Converts enum values to strings
static const std::string& as_string(EnumType e) {
for (const auto& p : get().m_string_enums) {
if (p.second == e) {
return p.first;
}
}
NGRAPH_CHECK(false, " invalid member of enum ", get().m_enum_name);
}
} // namespace ngraph
private:
/// Creates the mapping.
EnumNames(const std::string& enum_name, const std::vector<std::pair<std::string, EnumType>> string_enums)
: m_enum_name(enum_name),
m_string_enums(string_enums) {}
/// Must be defined to returns a singleton for each supported enum class
static EnumNames<EnumType>& get();
const std::string m_enum_name;
std::vector<std::pair<std::string, EnumType>> m_string_enums;
};
/// Returns the enum value matching the string
template <typename Type, typename Value>
typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type as_enum(const Value& value) {
return EnumNames<Type>::as_enum(value);
}
/// Returns the string matching the enum value
template <typename Value>
const std::string& as_string(Value value) {
return EnumNames<Value>::as_string(value);
}
} // namespace ngraph

View File

@ -5,35 +5,33 @@
#pragma once
#include <cstdint>
#include <ngraph/ngraph_visibility.hpp>
#include <string>
#include <ngraph/ngraph_visibility.hpp>
namespace ngraph {
/// \brief Get the names environment variable as a string.
/// \param env_var The string name of the environment variable to get.
/// \return Returns string by value or an empty string if the environment
/// variable is not set.
NGRAPH_API
std::string getenv_string(const char* env_var);
namespace ngraph
{
/// \brief Get the names environment variable as a string.
/// \param env_var The string name of the environment variable to get.
/// \return Returns string by value or an empty string if the environment
/// variable is not set.
NGRAPH_API
std::string getenv_string(const char* env_var);
/// \brief Get the names environment variable as an integer. If the value is not a
/// valid integer then an exception is thrown.
/// \param env_var The string name of the environment variable to get.
/// \param default_value The value to return if the environment variable is not set.
/// \return Returns value or default_value if the environment variable is not set.
NGRAPH_API
int32_t getenv_int(const char* env_var, int32_t default_value = -1);
/// \brief Get the names environment variable as an integer. If the value is not a
/// valid integer then an exception is thrown.
/// \param env_var The string name of the environment variable to get.
/// \param default_value The value to return if the environment variable is not set.
/// \return Returns value or default_value if the environment variable is not set.
NGRAPH_API
int32_t getenv_int(const char* env_var, int32_t default_value = -1);
/// \brief Get the names environment variable as a boolean. If the value is not a
/// valid boolean then an exception is thrown. Valid booleans are one of
/// 1, 0, on, off, true, false
/// All values are case insensitive.
/// If the environment variable is not set the default_value is returned.
/// \param env_var The string name of the environment variable to get.
/// \param default_value The value to return if the environment variable is not set.
/// \return Returns the boolean value of the environment variable.
NGRAPH_API
bool getenv_bool(const char* env_var, bool default_value = false);
} // namespace ngraph
/// \brief Get the names environment variable as a boolean. If the value is not a
/// valid boolean then an exception is thrown. Valid booleans are one of
/// 1, 0, on, off, true, false
/// All values are case insensitive.
/// If the environment variable is not set the default_value is returned.
/// \param env_var The string name of the environment variable to get.
/// \param default_value The value to return if the environment variable is not set.
/// \return Returns the boolean value of the environment variable.
NGRAPH_API
bool getenv_bool(const char* env_var, bool default_value = false);
} // namespace ngraph

View File

@ -12,192 +12,165 @@
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type_traits.hpp"
namespace ngraph
{
/// \brief Execute handlers on a subgraph to compute values
namespace ngraph {
/// \brief Execute handlers on a subgraph to compute values
///
///
template <typename V>
class Evaluator {
public:
/// \brief values we compute for outputs
using value_map = std::map<RawNodeOutput, V>;
/// \brief Handler for a computation of a value about an op
///
/// A handler is passed a Node* and a vector of computed input values. The handler should
/// return a vector of computed output values.
using op_handler = std::function<std::vector<V>(Node* op, std::vector<V>& inputs)>;
/// \brief Table of ops with handlers
using op_handler_map = std::map<Node::type_info_t, op_handler>;
/// \brief construct handler using the provided op handlers.
///
template <typename V>
class Evaluator
{
public:
/// \brief values we compute for outputs
using value_map = std::map<RawNodeOutput, V>;
/// Evaluations share previously computed values so that calls on multiple nodes can share
/// work. All state is kept in the value map, which is accessible for clearing or seeding
/// with
/// Evaluator::get_value_map().
///
/// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
Evaluator(const op_handler_map& handlers, value_map& values) : m_handlers(handlers), m_value_map(values) {}
/// \brief Handler for a computation of a value about an op
///
/// A handler is passed a Node* and a vector of computed input values. The handler should
/// return a vector of computed output values.
using op_handler = std::function<std::vector<V>(Node* op, std::vector<V>& inputs)>;
/// \brief Retrieves the value_map, which holds all Output<Node> value associations.
value_map& get_value_map() {
return m_value_map;
}
const value_map& get_value_map() const {
return m_value_map;
}
/// \brief If set, handles all ops
const op_handler& get_univeral_handler() const {
return m_universal_handler;
}
/// \brief If set, handles all ops not in the handlers
const op_handler& get_default_handler() const {
return m_default_handler;
}
/// \brief If set, handles all ops
void set_univeral_handler(const op_handler& handler) {
m_universal_handler = handler;
}
/// \brief If set, handles all ops not in the handlers
void set_default_handler(const op_handler& handler) {
m_default_handler = handler;
}
/// \brief Table of ops with handlers
using op_handler_map = std::map<Node::type_info_t, op_handler>;
/// \brief construct handler using the provided op handlers.
///
/// Evaluations share previously computed values so that calls on multiple nodes can share
/// work. All state is kept in the value map, which is accessible for clearing or seeding
/// with
/// Evaluator::get_value_map().
///
/// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
Evaluator(const op_handler_map& handlers, value_map& values)
: m_handlers(handlers)
, m_value_map(values)
{
protected:
op_handler get_handler(Node* node) {
op_handler handler = m_universal_handler;
if (!handler) {
auto it = m_handlers.find(node->get_type_info());
if (it == m_handlers.end()) {
handler = m_default_handler;
} else {
handler = it->second;
}
}
return handler;
}
/// \brief Retrieves the value_map, which holds all Output<Node> value associations.
value_map& get_value_map() { return m_value_map; }
const value_map& get_value_map() const { return m_value_map; }
/// \brief If set, handles all ops
const op_handler& get_univeral_handler() const { return m_universal_handler; }
/// \brief If set, handles all ops not in the handlers
const op_handler& get_default_handler() const { return m_default_handler; }
/// \brief If set, handles all ops
void set_univeral_handler(const op_handler& handler) { m_universal_handler = handler; }
/// \brief If set, handles all ops not in the handlers
void set_default_handler(const op_handler& handler) { m_default_handler = handler; }
class Inst;
using InstPtr = std::unique_ptr<Inst>;
using InstStack = std::stack<InstPtr>;
/// \brief Intstructions for evaluations state machine
class Inst {
protected:
op_handler get_handler(Node* node)
{
op_handler handler = m_universal_handler;
if (!handler)
{
auto it = m_handlers.find(node->get_type_info());
if (it == m_handlers.end())
{
handler = m_default_handler;
}
else
{
handler = it->second;
}
}
return handler;
}
class Inst;
using InstPtr = std::unique_ptr<Inst>;
using InstStack = std::stack<InstPtr>;
/// \brief Intstructions for evaluations state machine
class Inst
{
protected:
Inst(Node* node)
: m_node(node)
{
}
public:
virtual ~Inst() {}
virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
Node* get_node() { return m_node; }
protected:
Node* m_node;
};
/// \brief Ensure value has been analyzed
class ValueInst : public Inst
{
public:
ValueInst(const Output<Node>& value)
: Inst(value.get_node())
, m_index(value.get_index())
{
}
ValueInst(const RawNodeOutput& value)
: Inst(value.node)
, m_index(value.index)
{
}
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
{
// Request to analyze this value if we can
if (auto handler = evaluator.get_handler(node))
{
// Ensure the inputs are processed and then execute the op handler
inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
for (auto v : node->input_values())
{
inst_stack.push(InstPtr(new ValueInst(v)));
}
}
else
{
// We don't know how to handle this op, so mark the outputs as unknown
for (auto output : node->outputs())
{
evaluator.get_value_map()[output] = V();
}
}
}
private:
int64_t m_index;
};
/// \brief All arguments have been handled; execute the node handler
class ExecuteInst : public Inst
{
public:
ExecuteInst(Node* node, op_handler& handler)
: Inst(node)
, m_handler(handler)
{
}
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
{
// Request to execute the handleer. Pass what we know about the inputs to the
// handler and associate the results with the outputs
std::vector<V> inputs;
for (auto v : node->input_values())
{
inputs.push_back(evaluator.get_value_map().at(v));
}
std::vector<V> outputs = m_handler(node, inputs);
for (size_t i = 0; i < outputs.size(); ++i)
{
evaluator.get_value_map()[node->output(i)] = outputs[i];
}
}
private:
op_handler m_handler;
};
Inst(Node* node) : m_node(node) {}
public:
/// \brief Determine information about value
V evaluate(const Output<Node>& value)
{
InstStack inst_stack;
inst_stack.push(InstPtr(new ValueInst(value)));
while (!inst_stack.empty())
{
InstPtr inst;
std::swap(inst_stack.top(), inst);
inst_stack.pop();
auto node = inst->get_node();
if (m_value_map.find(node->output(0)) != m_value_map.end())
{
// Already computed
continue;
}
inst->handle(*this, inst_stack, node);
}
return m_value_map.at(value);
virtual ~Inst() {}
virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
Node* get_node() {
return m_node;
}
protected:
op_handler m_universal_handler;
op_handler_map m_handlers;
op_handler m_default_handler;
value_map& m_value_map;
Node* m_node;
};
} // namespace ngraph
/// \brief Ensure value has been analyzed
class ValueInst : public Inst {
public:
ValueInst(const Output<Node>& value) : Inst(value.get_node()), m_index(value.get_index()) {}
ValueInst(const RawNodeOutput& value) : Inst(value.node), m_index(value.index) {}
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
// Request to analyze this value if we can
if (auto handler = evaluator.get_handler(node)) {
// Ensure the inputs are processed and then execute the op handler
inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
for (auto v : node->input_values()) {
inst_stack.push(InstPtr(new ValueInst(v)));
}
} else {
// We don't know how to handle this op, so mark the outputs as unknown
for (auto output : node->outputs()) {
evaluator.get_value_map()[output] = V();
}
}
}
private:
int64_t m_index;
};
/// \brief All arguments have been handled; execute the node handler
class ExecuteInst : public Inst {
public:
ExecuteInst(Node* node, op_handler& handler) : Inst(node), m_handler(handler) {}
void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
// Request to execute the handleer. Pass what we know about the inputs to the
// handler and associate the results with the outputs
std::vector<V> inputs;
for (auto v : node->input_values()) {
inputs.push_back(evaluator.get_value_map().at(v));
}
std::vector<V> outputs = m_handler(node, inputs);
for (size_t i = 0; i < outputs.size(); ++i) {
evaluator.get_value_map()[node->output(i)] = outputs[i];
}
}
private:
op_handler m_handler;
};
public:
/// \brief Determine information about value
V evaluate(const Output<Node>& value) {
InstStack inst_stack;
inst_stack.push(InstPtr(new ValueInst(value)));
while (!inst_stack.empty()) {
InstPtr inst;
std::swap(inst_stack.top(), inst);
inst_stack.pop();
auto node = inst->get_node();
if (m_value_map.find(node->output(0)) != m_value_map.end()) {
// Already computed
continue;
}
inst->handle(*this, inst_stack, node);
}
return m_value_map.at(value);
}
protected:
op_handler m_universal_handler;
op_handler_map m_handlers;
op_handler m_default_handler;
value_map& m_value_map;
};
} // namespace ngraph

View File

@ -4,39 +4,23 @@
#pragma once
#include <ngraph/ngraph_visibility.hpp>
#include <sstream>
#include <stdexcept>
#include <ngraph/ngraph_visibility.hpp>
namespace ngraph {
/// Base error for ngraph runtime errors.
class NGRAPH_API ngraph_error : public std::runtime_error {
public:
explicit ngraph_error(const std::string& what_arg) : std::runtime_error(what_arg) {}
namespace ngraph
{
/// Base error for ngraph runtime errors.
class NGRAPH_API ngraph_error : public std::runtime_error
{
public:
explicit ngraph_error(const std::string& what_arg)
: std::runtime_error(what_arg)
{
}
explicit ngraph_error(const char* what_arg) : std::runtime_error(what_arg) {}
explicit ngraph_error(const char* what_arg)
: std::runtime_error(what_arg)
{
}
explicit ngraph_error(const std::stringstream& what_arg) : std::runtime_error(what_arg.str()) {}
};
explicit ngraph_error(const std::stringstream& what_arg)
: std::runtime_error(what_arg.str())
{
}
};
class NGRAPH_API unsupported_op : public std::runtime_error
{
public:
unsupported_op(const std::string& what_arg)
: std::runtime_error(what_arg)
{
}
};
} // namespace ngraph
class NGRAPH_API unsupported_op : public std::runtime_error {
public:
unsupported_op(const std::string& what_arg) : std::runtime_error(what_arg) {}
};
} // namespace ngraph

View File

@ -10,76 +10,68 @@
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
NGRAPH_API std::mutex& get_registry_mutex();
namespace ngraph {
NGRAPH_API std::mutex& get_registry_mutex();
/// \brief Registry of factories that can construct objects derived from BASE_TYPE
template <typename BASE_TYPE>
class FactoryRegistry
{
public:
using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
/// \brief Registry of factories that can construct objects derived from BASE_TYPE
template <typename BASE_TYPE>
class FactoryRegistry {
public:
using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
template <typename DERIVED_TYPE>
static Factory get_default_factory()
{
return []() { return new DERIVED_TYPE(); };
}
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
template <typename DERIVED_TYPE>
static Factory get_default_factory() {
return []() {
return new DERIVED_TYPE();
};
}
/// \brief Register a custom factory for type_info
void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[type_info] = factory;
}
/// \brief Register a custom factory for type_info
void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory) {
std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[type_info] = factory;
}
/// \brief Register a custom factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory(Factory factory)
{
register_factory(DERIVED_TYPE::type_info, factory);
}
/// \brief Register a custom factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory(Factory factory) {
register_factory(DERIVED_TYPE::type_info, factory);
}
/// \brief Register the defualt constructor factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory()
{
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
}
/// \brief Register the defualt constructor factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory() {
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
}
/// \brief Check to see if a factory is registered
bool has_factory(const typename BASE_TYPE::type_info_t& info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end();
}
/// \brief Check to see if a factory is registered
bool has_factory(const typename BASE_TYPE::type_info_t& info) {
std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end();
}
/// \brief Check to see if DERIVED_TYPE has a registered factory
template <typename DERIVED_TYPE>
bool has_factory()
{
return has_factory(DERIVED_TYPE::type_info);
}
/// \brief Check to see if DERIVED_TYPE has a registered factory
template <typename DERIVED_TYPE>
bool has_factory() {
return has_factory(DERIVED_TYPE::type_info);
}
/// \brief Create an instance for type_info
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info);
return it == m_factory_map.end() ? nullptr : it->second();
}
/// \brief Create an instance for type_info
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const {
std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info);
return it == m_factory_map.end() ? nullptr : it->second();
}
/// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
BASE_TYPE* create() const
{
return create(DERIVED_TYPE::type_info);
}
/// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
BASE_TYPE* create() const {
return create(DERIVED_TYPE::type_info);
}
protected:
FactoryMap m_factory_map;
};
} // namespace ngraph
protected:
FactoryMap m_factory_map;
};
} // namespace ngraph

View File

@ -8,47 +8,42 @@
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/factory.hpp"
namespace ngraph
{
template <typename BASE_TYPE>
class FactoryAttributeAdapter : public VisitorAdapter
{
public:
FactoryAttributeAdapter(std::shared_ptr<BASE_TYPE>& ref)
: m_ref(ref)
{
}
namespace ngraph {
template <typename BASE_TYPE>
class FactoryAttributeAdapter : public VisitorAdapter {
public:
FactoryAttributeAdapter(std::shared_ptr<BASE_TYPE>& ref) : m_ref(ref) {}
/// \brief Hook for extra processing before other attributes
virtual bool on_start(AttributeVisitor& /* visitor */) { return true; }
/// \brief Hook for extra processing after other attributes
virtual bool on_finish(AttributeVisitor& /* visitor */) { return true; }
bool visit_attributes(AttributeVisitor& visitor) override
{
if (on_start(visitor))
{
std::string type_info_name;
uint64_t type_info_version;
if (m_ref)
{
auto& type_info = m_ref->get_type_info();
type_info_name = type_info.name;
type_info_version = type_info.version;
}
visitor.on_attribute("name", type_info_name);
visitor.on_attribute("version", type_info_version);
if (m_ref)
{
visitor.start_structure("value");
m_ref->visit_attributes(visitor);
visitor.finish_structure();
}
on_finish(visitor);
/// \brief Hook for extra processing before other attributes
virtual bool on_start(AttributeVisitor& /* visitor */) {
return true;
}
/// \brief Hook for extra processing after other attributes
virtual bool on_finish(AttributeVisitor& /* visitor */) {
return true;
}
bool visit_attributes(AttributeVisitor& visitor) override {
if (on_start(visitor)) {
std::string type_info_name;
uint64_t type_info_version;
if (m_ref) {
auto& type_info = m_ref->get_type_info();
type_info_name = type_info.name;
type_info_version = type_info.version;
}
return true;
visitor.on_attribute("name", type_info_name);
visitor.on_attribute("version", type_info_version);
if (m_ref) {
visitor.start_structure("value");
m_ref->visit_attributes(visitor);
visitor.finish_structure();
}
on_finish(visitor);
}
return true;
}
protected:
std::shared_ptr<BASE_TYPE>& m_ref;
};
} // namespace ngraph
protected:
std::shared_ptr<BASE_TYPE>& m_ref;
};
} // namespace ngraph

View File

@ -5,70 +5,64 @@
#pragma once
#include <functional>
#include <ngraph/ngraph_visibility.hpp>
#include <string>
#include <vector>
#include <ngraph/ngraph_visibility.hpp>
namespace ngraph {
namespace file_util {
/// \brief Returns the name with extension for a given path
/// \param path The path to the output file
NGRAPH_API
std::string get_file_name(const std::string& path);
namespace ngraph
{
namespace file_util
{
/// \brief Returns the name with extension for a given path
/// \param path The path to the output file
NGRAPH_API
std::string get_file_name(const std::string& path);
/// \brief Returns the file extension
/// \param path The path to the output file
NGRAPH_API
std::string get_file_ext(const std::string& path);
/// \brief Returns the file extension
/// \param path The path to the output file
NGRAPH_API
std::string get_file_ext(const std::string& path);
/// \brief Returns the directory portion of the given path
/// \param path The path to the output file
NGRAPH_API
std::string get_directory(const std::string& path);
/// \brief Returns the directory portion of the given path
/// \param path The path to the output file
NGRAPH_API
std::string get_directory(const std::string& path);
/// \brief Joins multiple paths into a single path
/// \param s1 Left side of path
/// \param s2 Right side of path
NGRAPH_API
std::string path_join(const std::string& s1, const std::string& s2);
NGRAPH_API
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3);
NGRAPH_API
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3, const std::string& s4);
/// \brief Joins multiple paths into a single path
/// \param s1 Left side of path
/// \param s2 Right side of path
NGRAPH_API
std::string path_join(const std::string& s1, const std::string& s2);
NGRAPH_API
std::string path_join(const std::string& s1, const std::string& s2, const std::string& s3);
NGRAPH_API
std::string path_join(const std::string& s1,
const std::string& s2,
const std::string& s3,
const std::string& s4);
/// \brief Iterate through files and optionally directories. Symbolic links are skipped.
/// \param path The path to iterate over
/// \param func A callback function called with each file or directory encountered
/// \param recurse Optional parameter to enable recursing through path
NGRAPH_API
void iterate_files(const std::string& path,
std::function<void(const std::string& file, bool is_dir)> func,
bool recurse = false,
bool include_links = false);
/// \brief Iterate through files and optionally directories. Symbolic links are skipped.
/// \param path The path to iterate over
/// \param func A callback function called with each file or directory encountered
/// \param recurse Optional parameter to enable recursing through path
NGRAPH_API
void iterate_files(const std::string& path,
std::function<void(const std::string& file, bool is_dir)> func,
bool recurse = false,
bool include_links = false);
/// \brief Change Linux-style path ('/') to Windows-style ('\\')
/// \param path The path to change file separator
NGRAPH_API void convert_path_win_style(std::string& path);
/// \brief Change Linux-style path ('/') to Windows-style ('\\')
/// \param path The path to change file separator
NGRAPH_API void convert_path_win_style(std::string& path);
/// \brief Conversion from wide character string to a single-byte chain.
/// \param wstr A wide-char string
/// \return A multi-byte string
NGRAPH_API std::string wstring_to_string(const std::wstring& wstr);
/// \brief Conversion from wide character string to a single-byte chain.
/// \param wstr A wide-char string
/// \return A multi-byte string
NGRAPH_API std::string wstring_to_string(const std::wstring& wstr);
/// \brief Conversion from single-byte chain to wide character string.
/// \param str A null-terminated string
/// \return A wide-char string
NGRAPH_API std::wstring multi_byte_char_to_wstring(const char* str);
/// \brief Conversion from single-byte chain to wide character string.
/// \param str A null-terminated string
/// \return A wide-char string
NGRAPH_API std::wstring multi_byte_char_to_wstring(const char* str);
/// \brief Remove path components which would allow traversing up a directory tree.
/// \param path A path to file
/// \return A sanitiazed path
NGRAPH_API std::string sanitize_path(const std::string& path);
} // namespace file_util
} // namespace ngraph
/// \brief Remove path components which would allow traversing up a directory tree.
/// \param path A path to file
/// \return A sanitiazed path
NGRAPH_API std::string sanitize_path(const std::string& path);
} // namespace file_util
} // namespace ngraph

View File

@ -20,265 +20,258 @@
#include "ngraph/op/sink.hpp"
#include "ngraph/op/util/variable.hpp"
namespace ngraph
{
/// A user-defined function.
class NGRAPH_API Function
{
public:
static constexpr DiscreteTypeInfo type_info{"Function", 0};
const DiscreteTypeInfo& get_type_info() const { return type_info; }
Function(const NodeVector& results,
const ParameterVector& parameters,
const std::string& name = "");
namespace ngraph {
/// A user-defined function.
class NGRAPH_API Function {
public:
static constexpr DiscreteTypeInfo type_info{"Function", 0};
const DiscreteTypeInfo& get_type_info() const {
return type_info;
}
Function(const NodeVector& results, const ParameterVector& parameters, const std::string& name = "");
Function(const OutputVector& results,
const ParameterVector& parameters,
const std::string& name = "");
Function(const OutputVector& results, const ParameterVector& parameters, const std::string& name = "");
Function(const std::shared_ptr<Node>& result,
const ParameterVector& parameters,
const std::string& name = "");
Function(const std::shared_ptr<Node>& result, const ParameterVector& parameters, const std::string& name = "");
Function(const ResultVector& results,
const ParameterVector& parameters,
const std::string& name = "");
Function(const ResultVector& results, const ParameterVector& parameters, const std::string& name = "");
Function(const ResultVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name = "");
Function(const ResultVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name = "");
Function(const OutputVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name = "");
Function(const OutputVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name = "");
Function(const ResultVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const ResultVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const OutputVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const OutputVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const ResultVector& results,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const ResultVector& results,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const OutputVector& results,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
Function(const OutputVector& results,
const ParameterVector& parameters,
const VariableVector& variables,
const std::string& name = "");
/// Constructs a Function. Lists of parameters and variables will be generated automatically
/// based on traversing the graph from the results.
explicit Function(const OutputVector& results, const std::string& name = "");
/// Constructs a Function. Lists of parameters and variables will be generated automatically
/// based on traversing the graph from the results.
explicit Function(const OutputVector& results, const std::string& name = "");
/// Constructs a Function. Lists of parameters and variables will be generated automatically
/// based on traversing the graph from the results and the sinks.
Function(const OutputVector& results,
const SinkVector& sinks,
const std::string& name = "");
/// Constructs a Function. Lists of parameters and variables will be generated automatically
/// based on traversing the graph from the results and the sinks.
Function(const OutputVector& results, const SinkVector& sinks, const std::string& name = "");
virtual ~Function() = default;
/// Return the number of outputs for this function.
size_t get_output_size() const;
virtual ~Function() = default;
/// Return the number of outputs for this function.
size_t get_output_size() const;
/// Return the op that generates output i
std::shared_ptr<Node> get_output_op(size_t i) const;
/// Return the op that generates output i
std::shared_ptr<Node> get_output_op(size_t i) const;
Output<Node> output(size_t i) const;
Output<Node> output(size_t i) const;
/// Return the element type of output i
const element::Type& get_output_element_type(size_t i) const;
/// Return the element type of output i
const element::Type& get_output_element_type(size_t i) const;
/// Return the shape of element i
const Shape& get_output_shape(size_t i) const;
/// Return the shape of element i
const Shape& get_output_shape(size_t i) const;
/// Return the partial shape of element i
const PartialShape& get_output_partial_shape(size_t i) const;
/// Return the partial shape of element i
const PartialShape& get_output_partial_shape(size_t i) const;
/// Check that there is a single result and return it.
std::shared_ptr<Node> get_result() const;
/// Check that there is a single result and return it.
std::shared_ptr<Node> get_result() const;
/// \brief Get the unique name of the function.
/// \returns A const reference to the function's unique name.
const std::string& get_name() const;
/// \brief Get the unique name of the function.
/// \returns A const reference to the function's unique name.
const std::string& get_name() const;
/// \brief Sets a friendly name for a function. This does not overwrite the unique name
/// of the function and is retrieved via get_friendly_name(). Used mainly for
/// debugging.
/// \param name is the friendly name to set
void set_friendly_name(const std::string& name);
/// \brief Sets a friendly name for a function. This does not overwrite the unique name
/// of the function and is retrieved via get_friendly_name(). Used mainly for
/// debugging.
/// \param name is the friendly name to set
void set_friendly_name(const std::string& name);
/// \brief Gets the friendly name for a function. If no friendly name has been set via
/// set_friendly_name then the function's unique name is returned.
/// \returns A const reference to the function's friendly name.
const std::string& get_friendly_name() const;
/// \brief Gets the friendly name for a function. If no friendly name has been set via
/// set_friendly_name then the function's unique name is returned.
/// \returns A const reference to the function's friendly name.
const std::string& get_friendly_name() const;
std::vector<std::shared_ptr<Node>> get_ops() const;
std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
void map_unordered_ops(std::function<void(Node*)> f) const;
std::vector<std::shared_ptr<Node>> get_ops() const;
std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
void map_unordered_ops(std::function<void(Node*)> f) const;
friend std::ostream& operator<<(std::ostream&, const Function&);
// updates graph and m_results list
void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
friend std::ostream& operator<<(std::ostream&, const Function&);
// updates graph and m_results list
void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
void validate_nodes_and_infer_types() const;
void validate_nodes_and_infer_types() const;
/// \brief Returns the sum of the size of all nodes in the graph plus the size of
/// all constant data. This has little value beyond comparing the relative size of
/// graphs and should not be considered the actual memory consumption of a graph.
size_t get_graph_size() const;
/// \brief Returns the sum of the size of all nodes in the graph plus the size of
/// all constant data. This has little value beyond comparing the relative size of
/// graphs and should not be considered the actual memory consumption of a graph.
size_t get_graph_size() const;
/// \brief Returns true if any of the op's defined in the function contains partial shape
bool is_dynamic() const;
/// \brief Returns true if any of the op's defined in the function contains partial shape
bool is_dynamic() const;
/// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
///
/// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
/// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
///
/// \param parameter_index The index of the parameter to replace.
/// \param parameter The parameter to substitute for the `parameter_index`th parameter.
void replace_parameter(size_t parameter_index,
const std::shared_ptr<op::Parameter>& parameter);
/// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
///
/// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
/// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
///
/// \param parameter_index The index of the parameter to replace.
/// \param parameter The parameter to substitute for the `parameter_index`th parameter.
void replace_parameter(size_t parameter_index, const std::shared_ptr<op::Parameter>& parameter);
using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
const std::vector<std::shared_ptr<Node>>& root_nodes)>;
void set_topological_sort(topological_sort_t);
using topological_sort_t =
std::function<std::vector<std::shared_ptr<Node>>(const std::vector<std::shared_ptr<Node>>& root_nodes)>;
void set_topological_sort(topological_sort_t);
virtual bool visit_attributes(AttributeVisitor& visitor);
virtual bool visit_attributes(AttributeVisitor& visitor);
/// Return the function parameters
const ParameterVector& get_parameters() const { return m_parameters; };
/// Return a list of function's outputs
const ResultVector& get_results() const { return m_results; };
/// Index for parameter, or -1
int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
/// Index for value or result referencing it, or -1
int64_t get_result_index(const Output<Node>& value) const;
/// \brief Evaluate the function on inputs, putting results in outputs.
/// \param output_tensors Tensors for the outputs to compute. One for each result
/// \param input_tensors Tensors for the inputs. One for each inputs.
/// \param evaluation_context Storage of additional settings and attributes that can be used
/// when evaluating the function. This additional information can be shared across nodes.
bool evaluate(const HostTensorVector& output_tensors,
const HostTensorVector& input_tensors,
EvaluationContext evaluation_context = EvaluationContext()) const;
/// \brief Return a list of function's sinks.
const SinkVector& get_sinks() const { return m_sinks; }
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
/// manually after all changes.
/// \param sinks new sink nodes
void add_sinks(const SinkVector& sinks);
/// \brief Delete sink node from the list of sinks. Method doesn't delete node from graph.
/// \param sink Sink to delete
void remove_sink(const std::shared_ptr<op::Sink>& sink);
/// \brief Add new Result nodes to the list. Method doesn't validate graph, it should be
/// done manually after all changes.
/// \param results new Result nodes
void add_results(const ResultVector& results);
/// \brief Delete Result node from the list of results. Method will not delete node from
/// graph.
/// \param result Result node to delete
void remove_result(const std::shared_ptr<op::Result>& result);
/// \brief Add new Parameter nodes to the list.
///
/// Method doesn't change or validate graph, it should be done manually.
/// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the
/// following steps:
/// * replace node `ReadValue` by `Parameter` in graph
/// * call add_parameter() to add new input to the list
/// * call graph validation to check correctness of changes
///
/// \param params new Parameter nodes
void add_parameters(const ParameterVector& params);
/// \brief Delete Parameter node from the list of parameters. Method will not delete node
/// from graph. You need to replace Parameter with other operation manually.
/// Attention: Indexing of parameters can be changed.
///
/// Possible use of method is to replace input by variable. For it the following steps
/// should be done:
/// * `Parameter` node should be replaced by `ReadValue`
/// * call remove_parameter(param) to remove input from the list
/// * check if any parameter indexes are saved/used somewhere, update it for all inputs
/// because indexes can be changed
/// * call graph validation to check all changes
///
/// \param param Parameter node to delete
void remove_parameter(const std::shared_ptr<op::Parameter>& param);
/// \brief Add new variables to the list. Method doesn't validate graph, it should be done
/// manually after all changes.
/// \param variables new variables to add
void add_variables(const VariableVector& variables);
/// \brief Delete variable from the list of variables.
/// Method doesn't delete nodes that used this variable from the graph.
/// \param variable Variable to delete
void remove_variable(const VariablePtr& variable);
/// \brief Return a list of function's variables.
const VariableVector& get_variables() const { return m_variables; }
/// \brief Return a variable by specified variable_id.
VariablePtr get_variable_by_id(const std::string& variable_id) const;
private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
Function& operator=(const Function&) = delete;
/// \brief Depending on the options selected,
/// checks all the Parameter/Variables are registered in the list of Function
/// parameters/variables or finds all Parameters/Variables in a function and registers them.
/// \param detect_variables If this flag is true, then it finds all Variables in a function
/// and registers them, otherwise checks all the Variables are registered.
/// \param detect_parameters If this flag is true, then it finds all Parameters in a
/// function and registers them, otherwise checks all the Parameters are registered.
void prerequirements(bool detect_variables, bool detect_parameters);
static std::atomic<size_t> m_next_instance_id;
std::string m_name;
const std::string m_unique_name;
size_t m_placement{0};
topological_sort_t m_topological_sorter;
ResultVector m_results;
// List of the nodes with side effect in graph.
// These nodes are not outputs of graph but should not be removed even if have no children.
SinkVector m_sinks;
ParameterVector m_parameters;
VariableVector m_variables;
/// Return the function parameters
const ParameterVector& get_parameters() const {
return m_parameters;
};
template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>>
: public DirectValueAccessor<std::shared_ptr<Function>>
{
public:
AttributeAdapter(std::shared_ptr<Function>& value)
: DirectValueAccessor<std::shared_ptr<Function>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
/// Return a list of function's outputs
const ResultVector& get_results() const {
return m_results;
};
} // namespace ngraph
/// Index for parameter, or -1
int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
/// Index for value or result referencing it, or -1
int64_t get_result_index(const Output<Node>& value) const;
/// \brief Evaluate the function on inputs, putting results in outputs.
/// \param output_tensors Tensors for the outputs to compute. One for each result
/// \param input_tensors Tensors for the inputs. One for each inputs.
/// \param evaluation_context Storage of additional settings and attributes that can be used
/// when evaluating the function. This additional information can be shared across nodes.
bool evaluate(const HostTensorVector& output_tensors,
const HostTensorVector& input_tensors,
EvaluationContext evaluation_context = EvaluationContext()) const;
/// \brief Return a list of function's sinks.
const SinkVector& get_sinks() const {
return m_sinks;
}
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
/// manually after all changes.
/// \param sinks new sink nodes
void add_sinks(const SinkVector& sinks);
/// \brief Delete sink node from the list of sinks. Method doesn't delete node from graph.
/// \param sink Sink to delete
void remove_sink(const std::shared_ptr<op::Sink>& sink);
/// \brief Add new Result nodes to the list. Method doesn't validate graph, it should be
/// done manually after all changes.
/// \param results new Result nodes
void add_results(const ResultVector& results);
/// \brief Delete Result node from the list of results. Method will not delete node from
/// graph.
/// \param result Result node to delete
void remove_result(const std::shared_ptr<op::Result>& result);
/// \brief Add new Parameter nodes to the list.
///
/// Method doesn't change or validate graph, it should be done manually.
/// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the
/// following steps:
/// * replace node `ReadValue` by `Parameter` in graph
/// * call add_parameter() to add new input to the list
/// * call graph validation to check correctness of changes
///
/// \param params new Parameter nodes
void add_parameters(const ParameterVector& params);
/// \brief Delete Parameter node from the list of parameters. Method will not delete node
/// from graph. You need to replace Parameter with other operation manually.
/// Attention: Indexing of parameters can be changed.
///
/// Possible use of method is to replace input by variable. For it the following steps
/// should be done:
/// * `Parameter` node should be replaced by `ReadValue`
/// * call remove_parameter(param) to remove input from the list
/// * check if any parameter indexes are saved/used somewhere, update it for all inputs
/// because indexes can be changed
/// * call graph validation to check all changes
///
/// \param param Parameter node to delete
void remove_parameter(const std::shared_ptr<op::Parameter>& param);
/// \brief Add new variables to the list. Method doesn't validate graph, it should be done
/// manually after all changes.
/// \param variables new variables to add
void add_variables(const VariableVector& variables);
/// \brief Delete variable from the list of variables.
/// Method doesn't delete nodes that used this variable from the graph.
/// \param variable Variable to delete
void remove_variable(const VariablePtr& variable);
/// \brief Return a list of function's variables.
const VariableVector& get_variables() const {
return m_variables;
}
/// \brief Return a variable by specified variable_id.
VariablePtr get_variable_by_id(const std::string& variable_id) const;
private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
Function& operator=(const Function&) = delete;
/// \brief Depending on the options selected,
/// checks all the Parameter/Variables are registered in the list of Function
/// parameters/variables or finds all Parameters/Variables in a function and registers them.
/// \param detect_variables If this flag is true, then it finds all Variables in a function
/// and registers them, otherwise checks all the Variables are registered.
/// \param detect_parameters If this flag is true, then it finds all Parameters in a
/// function and registers them, otherwise checks all the Parameters are registered.
void prerequirements(bool detect_variables, bool detect_parameters);
static std::atomic<size_t> m_next_instance_id;
std::string m_name;
const std::string m_unique_name;
size_t m_placement{0};
topological_sort_t m_topological_sorter;
ResultVector m_results;
// List of the nodes with side effect in graph.
// These nodes are not outputs of graph but should not be removed even if have no children.
SinkVector m_sinks;
ParameterVector m_parameters;
VariableVector m_variables;
};
template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public DirectValueAccessor<std::shared_ptr<Function>> {
public:
AttributeAdapter(std::shared_ptr<Function>& value) : DirectValueAccessor<std::shared_ptr<Function>>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
} // namespace ngraph

View File

@ -18,467 +18,428 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace descriptor
{
class Input;
class Output;
} // namespace descriptor
namespace ngraph {
namespace descriptor {
class Input;
class Output;
} // namespace descriptor
namespace op
{
namespace v0
{
class Parameter;
}
} // namespace op
namespace op {
namespace v0 {
class Parameter;
}
} // namespace op
NGRAPH_API
void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f);
NGRAPH_API
void traverse_nodes(const std::shared_ptr<const Function> p, std::function<void(std::shared_ptr<Node>)> f);
NGRAPH_API
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
NGRAPH_API
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
/// \brief Visit each node in a sub-graph of the entire graph
/// \param subgraph_results The output nodes of the sub-graph
/// \param f Function to execute at each node in the traversal
/// \param subgraph_params Input nodes of the sub-graph (optional)
///
/// Traverses a sub-graph starting from subgraph_results moving up
/// towards parameter nodes. Traversal stops if it hits a node in
/// subgraph_params.
///
/// Most useful for finding parameters of a graph directly from the
/// result nodes and not from function parameters or extracting a
/// subgraph relevant to the computation of certain outputs
NGRAPH_API
void traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
const NodeVector& subgraph_params = {});
/// \brief Visit each node in a sub-graph of the entire graph
/// \param subgraph_results The output nodes of the sub-graph
/// \param f Function to execute at each node in the traversal
/// \param subgraph_params Input nodes of the sub-graph (optional)
///
/// Traverses a sub-graph starting from subgraph_results moving up
/// towards parameter nodes. Traversal stops if it hits a node in
/// subgraph_params.
///
/// Most useful for finding parameters of a graph directly from the
/// result nodes and not from function parameters or extracting a
/// subgraph relevant to the computation of certain outputs
NGRAPH_API
void traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
const NodeVector& subgraph_params = {});
/// \brief Replace the node `target` with the node `replacement`, i.e.,
/// redirect all users and control dependencies of `target` to
/// `replacement`.
///
/// \param target Node to be replaced.
/// \param replacement Node to replace `target` with.
/// \param output_order Vector determines order of replacement node's outputs.
///
/// This is primarily used in graph-rewriting passes. For example, we
/// might "fuse" two Concat operations as follows:
///
/// (Step 0: Original graph)
///
/// A B
/// | |
/// v v
/// N0[Concat, concatenation_axis=3] C
/// | |
/// v v
/// N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// (Step 1: Construct replacement)
///
/// shared_ptr<Node> new_N1 = make_shared<op::Concat>({A,B,C},3);
///
/// A----------------------------------------.
/// | |
/// | B----------------)--.
/// | | | |
/// v v | |
/// N0[Concat, concatenation_axis=3] C-----)--)--.
/// | | | | |
/// v v v v v
/// N1[Concat, concatenation_axis=3] new_N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// (Step 2: Replace N1 with new_N1)
///
/// replace_node(N1, new_N1);
///
/// A----------------------------------------.
/// | |
/// | B----------------)--.
/// | | | |
/// v v | |
/// N0[Concat, concatenation_axis=3] C-----)--)--.
/// | | | | |
/// v v v v v
/// N1[Concat, concatenation_axis=3] new_N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// (Step 3: N0 and N1 are now dead, nodes will be freed)
///
/// [happens automatically, once all shared_ptrs to N1 are released]
///
/// A----------------------------------------.
/// |
/// B----------------)--.
/// | |
/// | |
/// C-----)--)--.
/// | | |
/// v v v
/// new_N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// NOTE 1: replace_node is not type-safe (the graph is not revalidated).
/// For example, the following is allowed, even if node `some_user`
/// requires an input of shape 2x2:
///
/// (Before)
/// A(shape=2x2) B(shape=3x3)
/// |
/// v
/// some_user(requires 2x2 input)
///
/// (After -- graph is now invalid)
///
/// replace_node(A, B);
///
/// A(shape=2x2) B(shape=3x3)
/// |
/// v
/// some_user(requires 2x2 input)
///
/// NOTE 2: it is possible to insert a cycle into the graph with
/// replace_node, resulting in an invalid graph. Care must be taken to
/// avoid this. One common example is when you are attempting to insert a
/// new node `M` "after"` a node `N`. For example, you might expect this
/// to work:
///
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(N);
/// replace_node(M, N);
///
/// The problem is that at replacement time, N itself is a user of M. So
/// we end up introducing a cycle as follows:
///
/// N
/// |
/// v
/// other users...
///
/// |||
/// vvv
///
/// N------------>M
/// |
/// v
/// other users...
///
/// |||
/// vvv
///
/// .----.
/// | |
/// | |
/// N `----->M
/// |
/// v
/// other users...
///
/// To avoid the cycle, a valid way to perform the above desired insertion would be,
///
/// auto new_N = N->clone_with_new_inputs(N->input_values());
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
/// replace_node(N, M);
NGRAPH_API
void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order);
/// \brief Replace the node `target` with the node `replacement`, i.e.,
/// redirect all users and control dependencies of `target` to
/// `replacement`.
///
/// \param target Node to be replaced.
/// \param replacement Node to replace `target` with.
/// \param output_order Vector determines order of replacement node's outputs.
///
/// This is primarily used in graph-rewriting passes. For example, we
/// might "fuse" two Concat operations as follows:
///
/// (Step 0: Original graph)
///
/// A B
/// | |
/// v v
/// N0[Concat, concatenation_axis=3] C
/// | |
/// v v
/// N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// (Step 1: Construct replacement)
///
/// shared_ptr<Node> new_N1 = make_shared<op::Concat>({A,B,C},3);
///
/// A----------------------------------------.
/// | |
/// | B----------------)--.
/// | | | |
/// v v | |
/// N0[Concat, concatenation_axis=3] C-----)--)--.
/// | | | | |
/// v v v v v
/// N1[Concat, concatenation_axis=3] new_N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// (Step 2: Replace N1 with new_N1)
///
/// replace_node(N1, new_N1);
///
/// A----------------------------------------.
/// | |
/// | B----------------)--.
/// | | | |
/// v v | |
/// N0[Concat, concatenation_axis=3] C-----)--)--.
/// | | | | |
/// v v v v v
/// N1[Concat, concatenation_axis=3] new_N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// (Step 3: N0 and N1 are now dead, nodes will be freed)
///
/// [happens automatically, once all shared_ptrs to N1 are released]
///
/// A----------------------------------------.
/// |
/// B----------------)--.
/// | |
/// | |
/// C-----)--)--.
/// | | |
/// v v v
/// new_N1[Concat, concatenation_axis=3]
/// | |
/// v v
/// some_user another_user
///
/// NOTE 1: replace_node is not type-safe (the graph is not revalidated).
/// For example, the following is allowed, even if node `some_user`
/// requires an input of shape 2x2:
///
/// (Before)
/// A(shape=2x2) B(shape=3x3)
/// |
/// v
/// some_user(requires 2x2 input)
///
/// (After -- graph is now invalid)
///
/// replace_node(A, B);
///
/// A(shape=2x2) B(shape=3x3)
/// |
/// v
/// some_user(requires 2x2 input)
///
/// NOTE 2: it is possible to insert a cycle into the graph with
/// replace_node, resulting in an invalid graph. Care must be taken to
/// avoid this. One common example is when you are attempting to insert a
/// new node `M` "after"` a node `N`. For example, you might expect this
/// to work:
///
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(N);
/// replace_node(M, N);
///
/// The problem is that at replacement time, N itself is a user of M. So
/// we end up introducing a cycle as follows:
///
/// N
/// |
/// v
/// other users...
///
/// |||
/// vvv
///
/// N------------>M
/// |
/// v
/// other users...
///
/// |||
/// vvv
///
/// .----.
/// | |
/// | |
/// N `----->M
/// |
/// v
/// other users...
///
/// To avoid the cycle, a valid way to perform the above desired insertion would be,
///
/// auto new_N = N->clone_with_new_inputs(N->input_values());
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
/// replace_node(N, M);
NGRAPH_API
void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order);
/// Replace target.outputs[i] with replacement_values[i] and transfer control dependents and
/// provenance from target to the node(s) in replacement_values.
NGRAPH_API
void replace_node(const std::shared_ptr<Node>& target, const OutputVector& replacement_values);
/// Replace target.outputs[i] with replacement_values[i] and transfer control dependents and
/// provenance from target to the node(s) in replacement_values.
NGRAPH_API
void replace_node(const std::shared_ptr<Node>& target, const OutputVector& replacement_values);
NGRAPH_API
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
NGRAPH_API
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// \brief Replace multiple nodes in a function.
/// \param f Function where replacement is taking place.
/// \param parameter_replacement_map A mapping from parameter shared pointers to parameter
/// shared pointers. For each pair (k,v) in the map, parameter
/// k is replaced by parameter v, except if k==v or k is not a
/// parameter bound by f, in which case the pair (k,v) is
/// ignored.
/// \param body_replacement_map A mapping from node shared pointers to node shared pointers.
/// For each pair (k,v) in the map, node k is replaced by node v,
/// except if k==v, the pair (k,v) is ignored.
/// Note that if k is a parameter, its users will be redirected to
/// v, but k will _not_ be replaced in the function's parameter
/// list.
///
/// Limitations:
///
/// - No check is made that the replaced nodes in `parameter_replacement_map` are actually
/// among the bound parameters of `f`. (If a parameter appears in the map that is not
/// bound by `f`, it will be silently ignored.)
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
/// `body_replacement_map`, behavior is unspecified.
NGRAPH_API
void replace_nodes(
const std::shared_ptr<Function>& f,
const std::unordered_map<std::shared_ptr<op::v0::Parameter>,
std::shared_ptr<op::v0::Parameter>>& parameter_replacement_map,
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>&
body_replacement_map);
/// \brief Replace multiple nodes in a function.
/// \param f Function where replacement is taking place.
/// \param parameter_replacement_map A mapping from parameter shared pointers to parameter
/// shared pointers. For each pair (k,v) in the map, parameter
/// k is replaced by parameter v, except if k==v or k is not a
/// parameter bound by f, in which case the pair (k,v) is
/// ignored.
/// \param body_replacement_map A mapping from node shared pointers to node shared pointers.
/// For each pair (k,v) in the map, node k is replaced by node v,
/// except if k==v, the pair (k,v) is ignored.
/// Note that if k is a parameter, its users will be redirected to
/// v, but k will _not_ be replaced in the function's parameter
/// list.
///
/// Limitations:
///
/// - No check is made that the replaced nodes in `parameter_replacement_map` are actually
/// among the bound parameters of `f`. (If a parameter appears in the map that is not
/// bound by `f`, it will be silently ignored.)
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
/// `body_replacement_map`, behavior is unspecified.
NGRAPH_API
void replace_nodes(const std::shared_ptr<Function>& f,
const std::unordered_map<std::shared_ptr<op::v0::Parameter>, std::shared_ptr<op::v0::Parameter>>&
parameter_replacement_map,
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>& body_replacement_map);
NGRAPH_API
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
NGRAPH_API
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// Topological sort of nodes needed to compute root_nodes
template <typename T>
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes)
{
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
std::vector<std::shared_ptr<Node>> result;
/// Topological sort of nodes needed to compute root_nodes
template <typename T>
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes) {
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
std::vector<std::shared_ptr<Node>> result;
for (auto& node : root_nodes)
{
nodes_to_do.push(node.get());
}
while (nodes_to_do.size() > 0)
{
Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0)
{
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
for (auto& node : root_nodes) {
nodes_to_do.push(node.get());
}
while (nodes_to_do.size() > 0) {
Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0) {
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i) {
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0) {
can_add = false;
nodes_to_do.push(dep);
}
for (auto& depptr : node->get_control_dependencies())
{
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
}
for (auto& depptr : node->get_control_dependencies()) {
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0) {
can_add = false;
nodes_to_do.push(dep);
}
if (can_add)
{
}
if (can_add) {
result.push_back(node->shared_from_this());
nodes_to_do.pop();
nodes_done.insert(node);
}
} else {
nodes_to_do.pop();
}
}
return result;
}
/// Topological sort of just nodes
template <typename T>
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes) {
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
std::unordered_set<Node*> nodes_to_emit;
std::vector<std::shared_ptr<Node>> result;
for (auto& node : nodes) {
nodes_to_emit.insert(node.get());
nodes_to_do.push(node.get());
}
// NB: Some centos versions implement std::list::size() by counting elements
size_t nodes_remaining = nodes_to_emit.size();
while (nodes_to_do.size() > 0 && nodes_remaining > 0) {
Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0) {
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i) {
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0) {
can_add = false;
nodes_to_do.push(dep);
}
}
for (auto& depptr : node->get_control_dependencies()) {
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0) {
can_add = false;
nodes_to_do.push(dep);
}
}
if (can_add) {
if (nodes_to_emit.count(node) != 0) {
result.push_back(node->shared_from_this());
nodes_to_do.pop();
nodes_done.insert(node);
nodes_remaining--;
}
}
else
{
nodes_to_do.pop();
nodes_done.insert(node);
}
}
return result;
}
/// Topological sort of just nodes
template <typename T>
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes)
{
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
std::unordered_set<Node*> nodes_to_emit;
std::vector<std::shared_ptr<Node>> result;
for (auto& node : nodes)
{
nodes_to_emit.insert(node.get());
nodes_to_do.push(node.get());
}
// NB: Some centos versions implement std::list::size() by counting elements
size_t nodes_remaining = nodes_to_emit.size();
while (nodes_to_do.size() > 0 && nodes_remaining > 0)
{
Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0)
{
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
{
can_add = false;
nodes_to_do.push(dep);
}
}
for (auto& depptr : node->get_control_dependencies())
{
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
}
if (can_add)
{
if (nodes_to_emit.count(node) != 0)
{
result.push_back(node->shared_from_this());
nodes_remaining--;
}
nodes_to_do.pop();
nodes_done.insert(node);
}
}
else
{
nodes_to_do.pop();
}
}
return result;
}
template <typename T>
void validate_nodes_and_infer_types(const T& nodes)
{
for (auto& node : subgraph_topological_sort(nodes))
{
node->revalidate_and_infer_types();
else {
nodes_to_do.pop();
}
}
return result;
}
// Check if all paths from X to a result go through Y
NGRAPH_API
bool is_post_dominated(Node* X, Node* Y);
template <typename T>
void validate_nodes_and_infer_types(const T& nodes) {
for (auto& node : subgraph_topological_sort(nodes)) {
node->revalidate_and_infer_types();
}
}
NGRAPH_API
bool is_equal_to_const_value(const std::string& const_value,
const Output<Node>& reduce_constant);
// Check if all paths from X to a result go through Y
NGRAPH_API
bool is_post_dominated(Node* X, Node* Y);
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
NGRAPH_API
std::vector<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
NGRAPH_API
bool is_equal_to_const_value(const std::string& const_value, const Output<Node>& reduce_constant);
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
NGRAPH_API
std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& node_map);
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
NGRAPH_API
std::vector<std::shared_ptr<ngraph::Node>> clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
NodeMap& node_map);
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func,
NodeMap& node_map);
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
NGRAPH_API
std::list<std::shared_ptr<ngraph::Node>> clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& node_map);
// input function is cloned and returned
NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func, NodeMap& node_map);
NGRAPH_API
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node);
// input function is cloned and returned
NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
NGRAPH_API
void insert_new_node_between(const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node,
const std::shared_ptr<Node>& new_node);
NGRAPH_API
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>> insert_result_parameter_split(
const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node);
NGRAPH_API
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape);
NGRAPH_API
void insert_new_node_between(const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node,
const std::shared_ptr<Node>& new_node);
NGRAPH_API
std::shared_ptr<Node> make_constant_from_string(std::string val,
const element::Type& element_type,
const Shape& shape);
NGRAPH_API
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape);
NGRAPH_API
bool is_zero(const Output<Node>& reduce_constant);
NGRAPH_API
std::shared_ptr<Node> make_constant_from_string(std::string val, const element::Type& element_type, const Shape& shape);
NGRAPH_API
NodeVector get_subgraph_outputs(const NodeVector& nodes,
const NodeVector& exclusions,
bool ignore_unused = false,
bool ignore_output_duplicates = true);
NGRAPH_API
bool is_zero(const Output<Node>& reduce_constant);
// Extract sub-graph computing the `results`. Stops backward traversal at either a Parameter
// node
// or a node that belongs to args
NGRAPH_API
NodeVector extract_subgraph(const NodeVector& results, const NodeVector& args);
NGRAPH_API
NodeVector get_subgraph_outputs(const NodeVector& nodes,
const NodeVector& exclusions,
bool ignore_unused = false,
bool ignore_output_duplicates = true);
NGRAPH_API
bool is_one(const Output<Node>& reduce_constant);
// Extract sub-graph computing the `results`. Stops backward traversal at either a Parameter
// node
// or a node that belongs to args
NGRAPH_API
NodeVector extract_subgraph(const NodeVector& results, const NodeVector& args);
NGRAPH_API
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
NGRAPH_API
bool is_one(const Output<Node>& reduce_constant);
// Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node`
NGRAPH_API
bool is_used(Node* node);
NGRAPH_API
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
// Returns count of `node` users that are still live in the graph
NGRAPH_API
size_t get_user_count(Node* node);
// Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node`
NGRAPH_API
bool is_used(Node* node);
// Return true if a node's user could potentially overwrite
// the output of this node with in-place kernels
NGRAPH_API
bool possibly_overwritten(Node* node);
// Returns count of `node` users that are still live in the graph
NGRAPH_API
size_t get_user_count(Node* node);
NGRAPH_API
bool is_strided(const Strides& strides);
// Return true if a node's user could potentially overwrite
// the output of this node with in-place kernels
NGRAPH_API
bool possibly_overwritten(Node* node);
NGRAPH_API
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
NGRAPH_API
bool is_strided(const Strides& strides);
NGRAPH_API
void plot_graph(
std::shared_ptr<Function> f,
const std::string& filename,
std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr);
NGRAPH_API
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
/// \return A vector containing handles for each input of dst that is connected to an output
/// of `src`.
NGRAPH_API
std::vector<Input<Node>> get_inputs_from(Node& src, Node& dst);
/// \return A vector containing a handle for each output of src that is connected to an input
/// of `dst`.
NGRAPH_API
std::vector<Output<Node>> get_outputs_to(Node& src, Node& dst);
NGRAPH_API
void plot_graph(std::shared_ptr<Function> f,
const std::string& filename,
std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr);
/// Checks the func for graph cycles starting from results going backwards, then from parameters
/// going forward.
/// It returns true if a cycle is found and the first cycle encountered.
NGRAPH_API
bool check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle);
/// \return A vector containing handles for each input of dst that is connected to an output
/// of `src`.
NGRAPH_API
std::vector<Input<Node>> get_inputs_from(Node& src, Node& dst);
/// \return A vector containing a handle for each output of src that is connected to an input
/// of `dst`.
NGRAPH_API
std::vector<Output<Node>> get_outputs_to(Node& src, Node& dst);
NGRAPH_API
bool replace_output_update_name(Output<Node> node, const Output<Node>& node_input);
/// Checks the func for graph cycles starting from results going backwards, then from parameters
/// going forward.
/// It returns true if a cycle is found and the first cycle encountered.
NGRAPH_API
bool check_for_cycles(const ngraph::Function* func, ngraph::NodeVector& cycle_nodes, bool& is_bkwd_cycle);
NGRAPH_API
bool replace_node_update_name(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
} // namespace ngraph
NGRAPH_API
bool replace_output_update_name(Output<Node> node, const Output<Node>& node_input);
NGRAPH_API
bool replace_node_update_name(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
} // namespace ngraph

View File

@ -12,99 +12,109 @@
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
/// \brief Interval arithmetic
///
/// An interval is the set of integers from m_min_val through m_max_val.
/// The value s_max acts like infinity. The
/// addition, subtraction, or multiplication of intervals is the smallest interval
/// containing the sums, differences, or products of elements of the two intervals. An empty
/// interval is canonicalized to [s_max, s_max].
class NGRAPH_API Interval
{
public:
using value_type = std::int64_t;
using size_type = std::uint64_t;
namespace ngraph {
/// \brief Interval arithmetic
///
/// An interval is the set of integers from m_min_val through m_max_val.
/// The value s_max acts like infinity. The
/// addition, subtraction, or multiplication of intervals is the smallest interval
/// containing the sums, differences, or products of elements of the two intervals. An empty
/// interval is canonicalized to [s_max, s_max].
class NGRAPH_API Interval {
public:
using value_type = std::int64_t;
using size_type = std::uint64_t;
/// \brief Interval of everything
Interval() = default;
/// \brief Copy constructor
Interval(const Interval& interval) = default;
/// \brief Interval of everything
Interval() = default;
/// \brief Copy constructor
Interval(const Interval& interval) = default;
/// \brief Closed interval {x|min_val <= x <= max_val}
Interval(value_type min_val, value_type max_val);
/// \brief Closed interval {x|min_val <= x <= max_val}
Interval(value_type min_val, value_type max_val);
/// \brief Single-valued interval; just contains val
Interval(value_type val);
/// \brief Single-valued interval; just contains val
Interval(value_type val);
Interval& operator=(const Interval& interval) = default;
Interval& operator=(const Interval& interval) = default;
/// \brief The number of elements in the interval. Zero if max < min.
size_type size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
/// \brief The number of elements in the interval. Zero if max < min.
size_type size() const {
if (m_max_val == s_max) {
return m_min_val == s_max ? 0 : s_max;
}
/// \brief Returns true if the interval has no elements
bool empty() const { return m_min_val == s_max; }
/// \brief the inclusive lower bound of the interval
value_type get_min_val() const { return m_min_val; }
/// \brief Set the inclusive lower bound of the interval
void set_min_val(value_type val) { m_min_val = val; }
/// \brief the inclusive upper bound of the interval
value_type get_max_val() const { return m_max_val; }
/// \brief Set the inclusive upper bound of the interval
void set_max_val(value_type val) { m_max_val = val; }
/// \brief True if the upper bound is finite
bool has_upper_bound() const { return m_max_val != s_max; }
/// \brief True if min and max bounds match
bool operator==(const Interval& interval) const;
bool operator!=(const Interval& interval) const;
return m_max_val - m_min_val + 1;
}
/// \brief Returns true if the interval has no elements
bool empty() const {
return m_min_val == s_max;
}
/// \brief the inclusive lower bound of the interval
value_type get_min_val() const {
return m_min_val;
}
/// \brief Set the inclusive lower bound of the interval
void set_min_val(value_type val) {
m_min_val = val;
}
/// \brief the inclusive upper bound of the interval
value_type get_max_val() const {
return m_max_val;
}
/// \brief Set the inclusive upper bound of the interval
void set_max_val(value_type val) {
m_max_val = val;
}
/// \brief True if the upper bound is finite
bool has_upper_bound() const {
return m_max_val != s_max;
}
/// \brief True if min and max bounds match
bool operator==(const Interval& interval) const;
bool operator!=(const Interval& interval) const;
/// \brief The interval whose elements are a sum of an element from each interval
Interval operator+(const Interval& interval) const;
/// \brief The interval whose elements are a sum of an element from each interval
Interval operator+(const Interval& interval) const;
/// \brief Extend this interval to sums of elements in this interval and interval
Interval& operator+=(const Interval& interval);
/// \brief Extend this interval to sums of elements in this interval and interval
Interval& operator+=(const Interval& interval);
/// \brief The interval whose elements are a difference of an element from each interval
Interval operator-(const Interval& interval) const;
/// \brief The interval whose elements are a difference of an element from each interval
Interval operator-(const Interval& interval) const;
/// \brief Extend this interval to differences of elements in this interval and interval
Interval& operator-=(const Interval& interval);
/// \brief Extend this interval to differences of elements in this interval and interval
Interval& operator-=(const Interval& interval);
/// \brief The smallest interval whose elements are a product of an element from each
/// interval
Interval operator*(const Interval& interval) const;
/// \brief The smallest interval whose elements are a product of an element from each
/// interval
Interval operator*(const Interval& interval) const;
/// \brief Extend this interval to products of elements in this interval and interval
Interval& operator*=(const Interval& interval);
/// \brief Extend this interval to products of elements in this interval and interval
Interval& operator*=(const Interval& interval);
/// \brief The interval that is the intersection of this interval and interval
Interval operator&(const Interval& interval) const;
/// \brief The interval that is the intersection of this interval and interval
Interval operator&(const Interval& interval) const;
/// \brief Change this interval to only include elements also in interval
Interval& operator&=(const Interval& interval);
/// \brief Change this interval to only include elements also in interval
Interval& operator&=(const Interval& interval);
/// \brief True if this interval includes value
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
/// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const;
/// \brief True if this interval includes value
bool contains(value_type value) const {
return m_min_val <= value && value <= m_max_val;
}
/// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const;
/// \brief The value used for no upper bound
static constexpr value_type s_max{std::numeric_limits<value_type>::max()};
/// \brief The value used for no upper bound
static constexpr value_type s_max{std::numeric_limits<value_type>::max()};
protected:
void canonicalize();
protected:
void canonicalize();
value_type m_min_val{0};
value_type m_max_val{s_max};
};
value_type m_min_val{0};
value_type m_max_val{s_max};
};
NGRAPH_API
std::ostream& operator<<(std::ostream& str, const Interval& interval);
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream& str, const Interval& interval);
} // namespace ngraph

View File

@ -13,161 +13,141 @@
#include <sstream>
#include <stdexcept>
#if defined(__linux) || defined(__APPLE__)
#include <sys/time.h>
#include <unistd.h>
# include <sys/time.h>
# include <unistd.h>
#endif
#include <ngraph/ngraph_visibility.hpp>
#include <vector>
#include <ngraph/ngraph_visibility.hpp>
namespace ngraph {
class ConstString {
public:
template <size_t SIZE>
constexpr ConstString(const char (&p)[SIZE]) : m_string(p),
m_size(SIZE) {}
namespace ngraph
{
class ConstString
{
public:
template <size_t SIZE>
constexpr ConstString(const char (&p)[SIZE])
: m_string(p)
, m_size(SIZE)
{
}
constexpr char operator[](size_t i) const
{
return i < m_size ? m_string[i] : throw std::out_of_range("");
}
constexpr const char* get_ptr(size_t offset) const
{
return offset < m_size ? &m_string[offset] : m_string;
}
constexpr size_t size() const { return m_size; }
private:
const char* m_string;
size_t m_size;
};
constexpr const char* find_last(ConstString s, size_t offset, char ch)
{
return offset == 0
? s.get_ptr(0)
: (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
constexpr char operator[](size_t i) const {
return i < m_size ? m_string[i] : throw std::out_of_range("");
}
constexpr const char* get_ptr(size_t offset) const {
return offset < m_size ? &m_string[offset] : m_string;
}
constexpr size_t size() const {
return m_size;
}
constexpr const char* find_last(ConstString s, char ch)
{
return find_last(s, s.size() - 1, ch);
private:
const char* m_string;
size_t m_size;
};
constexpr const char* find_last(ConstString s, size_t offset, char ch) {
return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
}
constexpr const char* find_last(ConstString s, char ch) {
return find_last(s, s.size() - 1, ch);
}
constexpr const char* get_file_name(ConstString s) {
return find_last(s, '/');
}
constexpr const char* trim_file_name(ConstString root, ConstString s) {
return s.get_ptr(root.size());
}
enum class LOG_TYPE {
_LOG_TYPE_ERROR,
_LOG_TYPE_WARNING,
_LOG_TYPE_INFO,
_LOG_TYPE_DEBUG,
};
class NGRAPH_API LogHelper {
public:
LogHelper(LOG_TYPE, const char* file, int line, std::function<void(const std::string&)> m_handler_func);
~LogHelper();
std::ostream& stream() {
return m_stream;
}
constexpr const char* get_file_name(ConstString s) { return find_last(s, '/'); }
constexpr const char* trim_file_name(ConstString root, ConstString s)
{
return s.get_ptr(root.size());
}
enum class LOG_TYPE
{
_LOG_TYPE_ERROR,
_LOG_TYPE_WARNING,
_LOG_TYPE_INFO,
_LOG_TYPE_DEBUG,
};
private:
std::function<void(const std::string&)> m_handler_func;
std::stringstream m_stream;
};
class NGRAPH_API LogHelper
{
public:
LogHelper(LOG_TYPE,
const char* file,
int line,
std::function<void(const std::string&)> m_handler_func);
~LogHelper();
class Logger {
friend class LogHelper;
std::ostream& stream() { return m_stream; }
public:
static void set_log_path(const std::string& path);
static void start();
static void stop();
private:
std::function<void(const std::string&)> m_handler_func;
std::stringstream m_stream;
};
private:
static void log_item(const std::string& s);
static void process_event(const std::string& s);
static void thread_entry(void* param);
static std::string m_log_path;
static std::deque<std::string> m_queue;
};
class Logger
{
friend class LogHelper;
public:
static void set_log_path(const std::string& path);
static void start();
static void stop();
private:
static void log_item(const std::string& s);
static void process_event(const std::string& s);
static void thread_entry(void* param);
static std::string m_log_path;
static std::deque<std::string> m_queue;
};
NGRAPH_API
void default_logger_handler_func(const std::string& s);
NGRAPH_API
void default_logger_handler_func(const std::string& s);
#ifndef PROJECT_ROOT_DIR
#define PROJECT_ROOT_DIR ""
# define PROJECT_ROOT_DIR ""
#endif
#define NGRAPH_ERR \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_ERROR, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
#define NGRAPH_ERR \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_ERROR, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
.stream()
#define NGRAPH_WARN \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_WARNING, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
#define NGRAPH_WARN \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_WARNING, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
.stream()
#define NGRAPH_INFO \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_INFO, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
#define NGRAPH_INFO \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_INFO, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
.stream()
#ifdef NGRAPH_DEBUG_ENABLE
#define NGRAPH_DEBUG \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_DEBUG, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
.stream()
# define NGRAPH_DEBUG \
ngraph::LogHelper(ngraph::LOG_TYPE::_LOG_TYPE_DEBUG, \
ngraph::trim_file_name(PROJECT_ROOT_DIR, __FILE__), \
__LINE__, \
ngraph::default_logger_handler_func) \
.stream()
#else
struct NullLogger
{
};
struct NullLogger {};
template <typename T>
NullLogger&& operator<<(NullLogger&& logger, T&&)
{
return std::move(logger);
}
template <typename T>
NullLogger&& operator<<(NullLogger&& logger, T&&) {
return std::move(logger);
}
template <typename T>
NullLogger&& operator<<(NullLogger&& logger, const T&)
{
return std::move(logger);
}
template <typename T>
NullLogger&& operator<<(NullLogger&& logger, const T&) {
return std::move(logger);
}
inline NullLogger&&
operator<<(NullLogger&& logger,
std::basic_ostream<char, std::char_traits<char>>& (&)(std::basic_ostream<
char,
std::char_traits<char>>&))
{
return std::move(logger);
}
inline NullLogger&& operator<<(
NullLogger&& logger,
std::basic_ostream<char, std::char_traits<char>>& (&)(std::basic_ostream<char, std::char_traits<char>>&)) {
return std::move(logger);
}
#define NGRAPH_DEBUG \
::ngraph::NullLogger {}
# define NGRAPH_DEBUG \
::ngraph::NullLogger {}
#endif
} // namespace ngraph
} // namespace ngraph

View File

@ -11,28 +11,27 @@
#include <string>
#ifdef IN_NGRAPH_LIBRARY
#error("ngraph.hpp is for external use only")
# error("ngraph.hpp is for external use only")
#endif
#include <ngraph/ngraph_visibility.hpp>
extern "C" NGRAPH_API const char* get_ngraph_version_string();
namespace ngraph
{
/// \brief Function to query parsed version information of the version of ngraph which
/// contains this function. Version information strictly follows Semantic Versioning
/// http://semver.org
/// \param major Returns the major part of the version
/// \param minor Returns the minor part of the version
/// \param patch Returns the patch part of the version
/// \param extra Returns the extra part of the version. This includes everything following
/// the patch version number.
///
/// \note Throws a runtime_error if there is an error during parsing
NGRAPH_API
void get_version(size_t& major, size_t& minor, size_t& patch, std::string& extra);
} // namespace ngraph
namespace ngraph {
/// \brief Function to query parsed version information of the version of ngraph which
/// contains this function. Version information strictly follows Semantic Versioning
/// http://semver.org
/// \param major Returns the major part of the version
/// \param minor Returns the minor part of the version
/// \param patch Returns the patch part of the version
/// \param extra Returns the extra part of the version. This includes everything following
/// the patch version number.
///
/// \note Throws a runtime_error if there is an error during parsing
NGRAPH_API
void get_version(size_t& major, size_t& minor, size_t& patch, std::string& extra);
} // namespace ngraph
/// \namespace ngraph
/// \brief The Intel nGraph C++ API.

View File

@ -9,27 +9,26 @@
// (or does nothing for static build)
#ifdef _WIN32
#pragma warning(disable : 4251)
#pragma warning(disable : 4275)
# pragma warning(disable : 4251)
# pragma warning(disable : 4275)
#endif
#ifdef NGRAPH_STATIC_LIBRARY // defined if we are building or calling NGRAPH as a static library
#define NGRAPH_API
#ifdef NGRAPH_STATIC_LIBRARY // defined if we are building or calling NGRAPH as a static library
# define NGRAPH_API
#else
#ifdef ngraph_EXPORTS // defined if we are building the NGRAPH DLL (instead of using it)
#define NGRAPH_API NGRAPH_HELPER_DLL_EXPORT
#else
#define NGRAPH_API NGRAPH_HELPER_DLL_IMPORT
#endif // ngraph_EXPORTS
#endif // NGRAPH_STATIC_LIBRARY
# ifdef ngraph_EXPORTS // defined if we are building the NGRAPH DLL (instead of using it)
# define NGRAPH_API NGRAPH_HELPER_DLL_EXPORT
# else
# define NGRAPH_API NGRAPH_HELPER_DLL_IMPORT
# endif // ngraph_EXPORTS
#endif // NGRAPH_STATIC_LIBRARY
#ifndef ENABLE_UNICODE_PATH_SUPPORT
#ifdef _WIN32
#if defined __INTEL_COMPILER || defined _MSC_VER
#define ENABLE_UNICODE_PATH_SUPPORT
#endif
#elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || \
defined(__clang__)
#define ENABLE_UNICODE_PATH_SUPPORT
#endif
# ifdef _WIN32
# if defined __INTEL_COMPILER || defined _MSC_VER
# define ENABLE_UNICODE_PATH_SUPPORT
# endif
# elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || defined(__clang__)
# define ENABLE_UNICODE_PATH_SUPPORT
# endif
#endif

File diff suppressed because it is too large Load Diff

View File

@ -13,116 +13,111 @@
#include "ngraph/type/element_type.hpp"
#include "ngraph/variant.hpp"
namespace ngraph
{
class Node;
namespace ngraph {
class Node;
template <typename NodeType>
class Output;
template <typename NodeType>
class Output;
template <typename NodeType>
class Input
{
};
template <typename NodeType>
class Input {};
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(Node* node, size_t index);
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<Node> {
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(Node* node, size_t index);
/// \return A pointer to the node referenced by this input handle.
Node* get_node() const;
/// \return The index of the input referred to by this input handle.
size_t get_index() const;
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const;
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const;
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const;
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const;
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const;
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
/// \return A pointer to the node referenced by this input handle.
Node* get_node() const;
/// \return The index of the input referred to by this input handle.
size_t get_index() const;
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const;
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const;
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const;
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const;
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const;
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
/// \brief Replaces the source output of this input.
/// \param new_source_output A handle for the output that will replace this input's source.
void replace_source_output(const Output<Node>& new_source_output) const;
/// \brief Replaces the source output of this input.
/// \param new_source_output A handle for the output that will replace this input's source.
void replace_source_output(const Output<Node>& new_source_output) const;
/// \return The reference to runtime info map
RTMap& get_rt_info();
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
/// \return The reference to runtime info map
RTMap& get_rt_info();
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
bool operator>(const Input& other) const;
bool operator<=(const Input& other) const;
bool operator>=(const Input& other) const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
bool operator>(const Input& other) const;
bool operator<=(const Input& other) const;
bool operator>=(const Input& other) const;
private:
Node* const m_node;
const size_t m_index;
};
private:
Node* const m_node;
const size_t m_index;
};
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<const Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(const Node* node, size_t index);
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<const Node> {
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(const Node* node, size_t index);
/// \return A pointer to the node referenced by this input handle.
const Node* get_node() const;
/// \return The index of the input referred to by this input handle.
size_t get_index() const;
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const;
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const;
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const;
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const;
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const;
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
/// \return A pointer to the node referenced by this input handle.
const Node* get_node() const;
/// \return The index of the input referred to by this input handle.
size_t get_index() const;
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const;
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const;
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const;
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const;
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const;
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
bool operator>(const Input& other) const;
bool operator<=(const Input& other) const;
bool operator>=(const Input& other) const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
bool operator>(const Input& other) const;
bool operator<=(const Input& other) const;
bool operator>=(const Input& other) const;
private:
const Node* const m_node;
const size_t m_index;
};
private:
const Node* const m_node;
const size_t m_index;
};
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<Node>& input);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<const Node>& input);
} // namespace ngraph
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<Node>& input);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<const Node>& input);
} // namespace ngraph

View File

@ -14,169 +14,158 @@
#include "ngraph/type/element_type.hpp"
#include "ngraph/variant.hpp"
namespace ngraph
{
class Node;
namespace ngraph {
class Node;
template <typename NodeType>
class Input;
template <typename NodeType>
class Input;
template <typename NodeType>
class Output
{
};
template <typename NodeType>
class Output {};
/// \brief A handle for one of a node's outputs.
template <>
class NGRAPH_API Output<Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(Node* node, size_t index);
/// \brief A handle for one of a node's outputs.
template <>
class NGRAPH_API Output<Node> {
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(Node* node, size_t index);
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<Node>& node, size_t index);
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<Node>& node, size_t index);
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node)
: Output(node ? node->get_default_output() : Output<Node>())
{
}
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node) : Output(node ? node->get_default_output() : Output<Node>()) {}
/// A null output
Output() = default;
/// A null output
Output() = default;
void reset();
void reset();
/// This output position for a different node
Output<Node> for_node(const std::shared_ptr<Node>& node);
/// \return A pointer to the node referred to by this output handle.
Node* get_node() const;
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<Node> get_node_shared_ptr() const;
/// This output position for a different node
Output<Node> for_node(const std::shared_ptr<Node>& node);
/// \return A pointer to the node referred to by this output handle.
Node* get_node() const;
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<Node> get_node_shared_ptr() const;
/// \return The index of the output referred to by this output handle.
size_t get_index() const;
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const;
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const;
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const;
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const;
/// \return The index of the output referred to by this output handle.
size_t get_index() const;
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const;
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const;
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const;
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const;
/// \return The reference to runtime info map
RTMap& get_rt_info();
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
/// \return The reference to runtime info map
RTMap& get_rt_info();
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
/// \brief Removes a target input from the output referenced by this output handle.
/// \param target_input The target input to remove.
///
// TODO(amprocte): Investigate whether this really ought to be public.
void remove_target_input(const Input<Node>& target_input) const;
/// \brief Removes a target input from the output referenced by this output handle.
/// \param target_input The target input to remove.
///
// TODO(amprocte): Investigate whether this really ought to be public.
void remove_target_input(const Input<Node>& target_input) const;
/// \brief Replace all users of this value with replacement
void replace(const Output<Node>& replacement);
/// \brief Replace all users of this value with replacement
void replace(const Output<Node>& replacement);
bool operator==(const Output& other) const;
bool operator!=(const Output& other) const;
bool operator<(const Output& other) const;
bool operator>(const Output& other) const;
bool operator<=(const Output& other) const;
bool operator>=(const Output& other) const;
bool operator==(const Output& other) const;
bool operator!=(const Output& other) const;
bool operator<(const Output& other) const;
bool operator>(const Output& other) const;
bool operator<=(const Output& other) const;
bool operator>=(const Output& other) const;
private:
std::shared_ptr<Node> m_node;
size_t m_index{0};
};
private:
std::shared_ptr<Node> m_node;
size_t m_index{0};
};
template <>
class NGRAPH_API Output<const Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(const Node* node, size_t index);
template <>
class NGRAPH_API Output<const Node> {
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(const Node* node, size_t index);
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<const Node>& node, size_t index);
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<const Node>& node, size_t index);
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node)
: Output(node ? node->get_default_output() : Output<const Node>())
{
}
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node) : Output(node ? node->get_default_output() : Output<const Node>()) {}
/// A null output
Output() = default;
/// A null output
Output() = default;
void reset();
void reset();
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node);
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node);
/// \return A pointer to the node referred to by this output handle.
const Node* get_node() const;
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<const Node> get_node_shared_ptr() const;
/// \return The index of the output referred to by this output handle.
size_t get_index() const;
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const;
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const;
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const;
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const;
/// \return A pointer to the node referred to by this output handle.
const Node* get_node() const;
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<const Node> get_node_shared_ptr() const;
/// \return The index of the output referred to by this output handle.
size_t get_index() const;
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const;
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const;
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const;
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const;
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
bool operator==(const Output& other) const;
bool operator!=(const Output& other) const;
bool operator<(const Output& other) const;
bool operator>(const Output& other) const;
bool operator<=(const Output& other) const;
bool operator>=(const Output& other) const;
bool operator==(const Output& other) const;
bool operator!=(const Output& other) const;
bool operator<(const Output& other) const;
bool operator>(const Output& other) const;
bool operator<=(const Output& other) const;
bool operator>=(const Output& other) const;
private:
std::shared_ptr<const Node> m_node;
size_t m_index{0};
};
private:
std::shared_ptr<const Node> m_node;
size_t m_index{0};
};
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
} // namespace ngraph
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
} // namespace ngraph

View File

@ -8,39 +8,37 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise absolute value operation.
///
class NGRAPH_API Abs : public util::UnaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Abs", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an absolute value operation.
Abs() = default;
bool visit_attributes(AttributeVisitor&) override { return true; }
/// \brief Constructs an absolute value operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Abs(const Output<Node>& arg);
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise absolute value operation.
///
class NGRAPH_API Abs : public util::UnaryElementwiseArithmetic {
public:
static constexpr NodeTypeInfo type_info{"Abs", 0};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs an absolute value operation.
Abs() = default;
bool visit_attributes(AttributeVisitor&) override {
return true;
}
/// \brief Constructs an absolute value operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Abs(const Output<Node>& arg);
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Abs;
} // namespace op
} // namespace ngraph
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Abs;
} // namespace op
} // namespace ngraph

View File

@ -8,37 +8,35 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise inverse cosine (arccos) operation.
///
class NGRAPH_API Acos : public util::UnaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Acos", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an arccos operation.
Acos() = default;
/// \brief Constructs an arccos operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor&) override { return true; }
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Acos;
} // namespace op
} // namespace ngraph
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise inverse cosine (arccos) operation.
///
class NGRAPH_API Acos : public util::UnaryElementwiseArithmetic {
public:
static constexpr NodeTypeInfo type_info{"Acos", 0};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs an arccos operation.
Acos() = default;
/// \brief Constructs an arccos operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor&) override {
return true;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Acos;
} // namespace op
} // namespace ngraph

View File

@ -8,38 +8,34 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Elementwise inverse hyperbolic cos operation.
///
class NGRAPH_API Acosh : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Elementwise inverse hyperbolic cos operation.
///
class NGRAPH_API Acosh : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an Acosh operation.
Acosh() = default;
/// \brief Constructs an Acosh operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Acosh(const Output<Node>& arg);
/// \brief Constructs an Acosh operation.
Acosh() = default;
/// \brief Constructs an Acosh operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Acosh(const Output<Node>& arg);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
using v3::Acosh;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override {
return true;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
using v3::Acosh;
} // namespace op
} // namespace ngraph

View File

@ -7,37 +7,32 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v8
{
/// \brief Adaptive average pooling operation.
///
class NGRAPH_API AdaptiveAvgPool : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v8 {
/// \brief Adaptive average pooling operation.
///
class NGRAPH_API AdaptiveAvgPool : public Op {
public:
NGRAPH_RTTI_DECLARATION;
AdaptiveAvgPool() = default;
AdaptiveAvgPool() = default;
///
/// \brief Constructs adaptive average pooling operation.
///
/// \param data Input data
///
/// \param output_shape 1D tensor describing output shape for spatial
/// dimensions.
///
AdaptiveAvgPool(const Output<Node>& data, const Output<Node>& output_shape);
///
/// \brief Constructs adaptive average pooling operation.
///
/// \param data Input data
///
/// \param output_shape 1D tensor describing output shape for spatial
/// dimensions.
///
AdaptiveAvgPool(const Output<Node>& data, const Output<Node>& output_shape);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v8
} // namespace op
} // namespace ngraph
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v8
} // namespace op
} // namespace ngraph

View File

@ -7,48 +7,44 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v8
{
/// \brief Adaptive max pooling operation.
///
class NGRAPH_API AdaptiveMaxPool : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v8 {
/// \brief Adaptive max pooling operation.
///
class NGRAPH_API AdaptiveMaxPool : public Op {
public:
NGRAPH_RTTI_DECLARATION;
AdaptiveMaxPool() = default;
AdaptiveMaxPool() = default;
///
/// \brief Constructs adaptive max pooling operation.
///
/// \param data Input data
///
/// \param output_shape 1D tensor describing output shape for spatial
/// dimensions.
///
/// \param index_element_type Specifies the output tensor type for indices
/// output
///
AdaptiveMaxPool(
const Output<Node>& data,
///
/// \brief Constructs adaptive max pooling operation.
///
/// \param data Input data
///
/// \param output_shape 1D tensor describing output shape for spatial
/// dimensions.
///
/// \param index_element_type Specifies the output tensor type for indices
/// output
///
AdaptiveMaxPool(const Output<Node>& data,
const Output<Node>& output_shape,
const ngraph::element::Type& index_element_type = ngraph::element::i64);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
element::Type get_index_element_type() const { return m_index_element_type; }
element::Type get_index_element_type() const {
return m_index_element_type;
}
protected:
ngraph::element::Type m_index_element_type = ngraph::element::i64;
};
} // namespace v8
} // namespace op
} // namespace ngraph
protected:
ngraph::element::Type m_index_element_type = ngraph::element::i64;
};
} // namespace v8
} // namespace op
} // namespace ngraph

View File

@ -8,50 +8,40 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Elementwise addition operation.
///
class NGRAPH_API Add : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Elementwise addition operation.
///
class NGRAPH_API Add : public util::BinaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an uninitialized addition operation
Add()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an uninitialized addition operation
Add() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) {}
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification. Default is Numpy-style
/// implicit broadcasting.
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification. Default is Numpy-style
/// implicit broadcasting.
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -8,43 +8,36 @@
#include "ngraph/op/util/binary_elementwise_logical.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Elementwise logical-and operation.
///
class NGRAPH_API LogicalAnd : public util::BinaryElementwiseLogical
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a logical-and operation.
LogicalAnd() = default;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Elementwise logical-and operation.
///
class NGRAPH_API LogicalAnd : public util::BinaryElementwiseLogical {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a logical-and operation.
LogicalAnd() = default;
/// \brief Constructs a logical-and operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
LogicalAnd(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a logical-and operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
LogicalAnd(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -8,38 +8,36 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise inverse sine (arcsin) operation.
///
class NGRAPH_API Asin : public util::UnaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Asin", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an arcsin operation.
Asin() = default;
/// \brief Constructs an arcsin operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Asin(const Output<Node>& arg);
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise inverse sine (arcsin) operation.
///
class NGRAPH_API Asin : public util::UnaryElementwiseArithmetic {
public:
static constexpr NodeTypeInfo type_info{"Asin", 0};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs an arcsin operation.
Asin() = default;
/// \brief Constructs an arcsin operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Asin(const Output<Node>& arg);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Asin;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override {
return true;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Asin;
} // namespace op
} // namespace ngraph

View File

@ -8,38 +8,34 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Elementwise inverse hyperbolic sin operation.
///
class NGRAPH_API Asinh : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Elementwise inverse hyperbolic sin operation.
///
class NGRAPH_API Asinh : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an Asinh operation.
Asinh() = default;
/// \brief Constructs an Asinh operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Asinh(const Output<Node>& arg);
/// \brief Constructs an Asinh operation.
Asinh() = default;
/// \brief Constructs an Asinh operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Asinh(const Output<Node>& arg);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
using v3::Asinh;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override {
return true;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
using v3::Asinh;
} // namespace op
} // namespace ngraph

View File

@ -8,86 +8,73 @@
#include "ngraph/op/util/variable.hpp"
#include "ngraph/op/util/variable_extension.hpp"
namespace ngraph
{
namespace op
{
class NGRAPH_API AssignBase : public Sink, public VariableExtension
{
public:
NGRAPH_RTTI_DECLARATION;
AssignBase() = default;
/// \brief Constructs an AssignBase operation.
explicit AssignBase(const OutputVector& arguments)
: Sink(arguments)
{
}
};
namespace ngraph {
namespace op {
class NGRAPH_API AssignBase : public Sink, public VariableExtension {
public:
NGRAPH_RTTI_DECLARATION;
AssignBase() = default;
/// \brief Constructs an AssignBase operation.
explicit AssignBase(const OutputVector& arguments) : Sink(arguments) {}
};
namespace v3
{
/// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public AssignBase
{
public:
NGRAPH_RTTI_DECLARATION;
Assign() = default;
namespace v3 {
/// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public AssignBase {
public:
NGRAPH_RTTI_DECLARATION;
Assign() = default;
/// \brief Constructs an Assign operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable_id identifier of the variable to be updated.
Assign(const Output<Node>& new_value, const std::string& variable_id);
/// \brief Constructs an Assign operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable_id identifier of the variable to be updated.
Assign(const Output<Node>& new_value, const std::string& variable_id);
void validate_and_infer_types() override;
std::string get_variable_id() const override { return m_variable_id; }
void validate_and_infer_types() override;
std::string get_variable_id() const override {
return m_variable_id;
}
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
private:
std::string m_variable_id;
};
} // namespace v3
namespace v6
{
/// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public AssignBase
{
public:
NGRAPH_RTTI_DECLARATION;
Assign() = default;
private:
std::string m_variable_id;
};
} // namespace v3
namespace v6 {
/// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public AssignBase {
public:
NGRAPH_RTTI_DECLARATION;
Assign() = default;
/// \brief Constructs an Assign operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable Class for storing and synchronizing element types, shapes and
/// identifiers
/// between pairs of Assign/ReadValue nodes.
Assign(const Output<Node>& new_value, const std::shared_ptr<Variable>& variable);
/// \brief Constructs an Assign operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable Class for storing and synchronizing element types, shapes and
/// identifiers
/// between pairs of Assign/ReadValue nodes.
Assign(const Output<Node>& new_value, const std::shared_ptr<Variable>& variable);
void validate_and_infer_types() override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::string get_variable_id() const override
{
NGRAPH_CHECK(m_variable,
"Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id;
}
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs,
const EvaluationContext& evaluation_context) const override;
bool has_evaluate() const override;
bool constant_fold(OutputVector& output_values,
const OutputVector& inputs_values) override;
};
} // namespace v6
} // namespace op
} // namespace ngraph
std::string get_variable_id() const override {
NGRAPH_CHECK(m_variable, "Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id;
}
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs,
const EvaluationContext& evaluation_context) const override;
bool has_evaluate() const override;
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -8,38 +8,34 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise inverse tangent (arctan) operation.
///
class NGRAPH_API Atan : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an arctan operation.
Atan() = default;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise inverse tangent (arctan) operation.
///
class NGRAPH_API Atan : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an arctan operation.
Atan() = default;
/// \brief Constructs an arctan operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Atan(const Output<Node>& arg);
/// \brief Constructs an arctan operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Atan(const Output<Node>& arg);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Atan;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override {
return true;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Atan;
} // namespace op
} // namespace ngraph

View File

@ -8,38 +8,34 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Elementwise inverse hyperbolic tangent operation.
///
class NGRAPH_API Atanh : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Elementwise inverse hyperbolic tangent operation.
///
class NGRAPH_API Atanh : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an Atanh operation.
Atanh() = default;
/// \brief Constructs an Atanh operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Atanh(const Output<Node>& arg);
/// \brief Constructs an Atanh operation.
Atanh() = default;
/// \brief Constructs an Atanh operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Atanh(const Output<Node>& arg);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
using v3::Atanh;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override {
return true;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
using v3::Atanh;
} // namespace op
} // namespace ngraph

View File

@ -7,89 +7,84 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Batched average pooling operation.
///
class NGRAPH_API AvgPool : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Batched average pooling operation.
///
class NGRAPH_API AvgPool : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a batched average pooling operation.
AvgPool() = default;
/// \brief Constructs a batched average pooling operation.
AvgPool() = default;
///
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param strides The strides.<br> `[n]`
/// \param pads_begin The beginning of padding shape.<br> `[n]`
/// \param pads_end The end of padding shape.<br> `[n]`
/// \param kernel The kernel shape.<br> `[n]`
/// \param exclude_pad If false then averages include padding elements, each
/// treated as the number zero. If true, padding
/// elements
/// are entirely ignored when computing averages.
/// \param rounding_type Whether to use ceiling or floor rounding type while
/// computing output shape.
/// \param auto_pad Padding type to use for additional padded dimensions
///
AvgPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad,
op::RoundingType rounding_type = op::RoundingType::FLOOR,
const PadType& auto_pad = op::PadType::EXPLICIT);
///
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param strides The strides.<br> `[n]`
/// \param pads_begin The beginning of padding shape.<br> `[n]`
/// \param pads_end The end of padding shape.<br> `[n]`
/// \param kernel The kernel shape.<br> `[n]`
/// \param exclude_pad If false then averages include padding elements, each
/// treated as the number zero. If true, padding
/// elements
/// are entirely ignored when computing averages.
/// \param rounding_type Whether to use ceiling or floor rounding type while
/// computing output shape.
/// \param auto_pad Padding type to use for additional padded dimensions
///
AvgPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad,
op::RoundingType rounding_type = op::RoundingType::FLOOR,
const PadType& auto_pad = op::PadType::EXPLICIT);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The kernel shape.
const Shape& get_kernel() const;
void set_kernel(const Shape& kernel);
/// \return The strides.
const Strides& get_strides() const;
void set_strides(const Strides& strides);
/// \return The beginning of padding shape.
const Shape& get_pads_begin() const;
void set_pads_begin(const Shape& pads_begin);
/// \return The end of padding shape.
const Shape& get_pads_end() const;
void set_pads_end(const Shape& pads_end);
bool get_exclude_pad() const;
void set_exclude_pad(bool exclude_pad);
/// \return The pad type for pooling.
const PadType& get_auto_pad() const;
void set_auto_pad(const PadType& auto_pad);
op::RoundingType get_rounding_type() const;
void set_rounding_type(op::RoundingType rounding_type);
/// \return The default value for AvgPool.
NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END
/// \return The kernel shape.
const Shape& get_kernel() const;
void set_kernel(const Shape& kernel);
/// \return The strides.
const Strides& get_strides() const;
void set_strides(const Strides& strides);
/// \return The beginning of padding shape.
const Shape& get_pads_begin() const;
void set_pads_begin(const Shape& pads_begin);
/// \return The end of padding shape.
const Shape& get_pads_end() const;
void set_pads_end(const Shape& pads_end);
bool get_exclude_pad() const;
void set_exclude_pad(bool exclude_pad);
/// \return The pad type for pooling.
const PadType& get_auto_pad() const;
void set_auto_pad(const PadType& auto_pad);
op::RoundingType get_rounding_type() const;
void set_rounding_type(op::RoundingType rounding_type);
/// \return The default value for AvgPool.
NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END
protected:
Shape m_kernel;
Strides m_strides;
Shape m_pads_begin;
Shape m_pads_end;
bool m_exclude_pad{true};
PadType m_auto_pad{PadType::EXPLICIT};
op::RoundingType m_rounding_type{op::RoundingType::FLOOR};
};
} // namespace v1
protected:
Shape m_kernel;
Strides m_strides;
Shape m_pads_begin;
Shape m_pads_end;
bool m_exclude_pad{true};
PadType m_auto_pad{PadType::EXPLICIT};
op::RoundingType m_rounding_type{op::RoundingType::FLOOR};
};
} // namespace v1
using v1::AvgPool;
} // namespace op
} // namespace ngraph
using v1::AvgPool;
} // namespace op
} // namespace ngraph

View File

@ -10,87 +10,87 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
class NGRAPH_API BatchNormInference : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon);
namespace ngraph {
namespace op {
namespace v0 {
class NGRAPH_API BatchNormInference : public Op {
public:
NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
double get_eps_value() const {
return m_epsilon;
}
void set_eps_value(double epsilon) {
m_epsilon = epsilon;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
private:
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
static constexpr size_t INPUT_MEAN = 3;
static constexpr size_t INPUT_VARIANCE = 4;
private:
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
static constexpr size_t INPUT_MEAN = 3;
static constexpr size_t INPUT_VARIANCE = 4;
double m_epsilon;
};
} // namespace v0
namespace v5
{
class NGRAPH_API BatchNormInference : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon);
double m_epsilon;
};
} // namespace v0
namespace v5 {
class NGRAPH_API BatchNormInference : public Op {
public:
NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
double get_eps_value() const {
return m_epsilon;
}
void set_eps_value(double epsilon) {
m_epsilon = epsilon;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
private:
static constexpr size_t INPUT_DATA = 0;
static constexpr size_t INPUT_GAMMA = 1;
static constexpr size_t INPUT_BETA = 2;
static constexpr size_t INPUT_MEAN = 3;
static constexpr size_t INPUT_VARIANCE = 4;
private:
static constexpr size_t INPUT_DATA = 0;
static constexpr size_t INPUT_GAMMA = 1;
static constexpr size_t INPUT_BETA = 2;
static constexpr size_t INPUT_MEAN = 3;
static constexpr size_t INPUT_VARIANCE = 4;
double m_epsilon;
};
} // namespace v5
} // namespace op
} // namespace ngraph
double m_epsilon;
};
} // namespace v5
} // namespace op
} // namespace ngraph

View File

@ -7,49 +7,43 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief BatchToSpace permutes data from the batch dimension of the data tensor into
/// spatial dimensions.
///
/// \note Values from the batch dimension are moved in spatial blocks dimensions.
///
/// Output node produces a tensor with shape:
/// `[batch / (block_shape[0] * block_shape[1] * ... * block_shape[N - 1]),
/// D_1 * block_shape[1] - crops_begin[1] - crops_end[1],
/// D_2 * block_shape[2] - crops_begin[2] - crops_end[2], ...,
/// D_{N - 1} * block_shape[N - 1] - crops_begin[N - 1] - crops_end[N - 1]`
/// of the same type as `data` input.
class NGRAPH_API BatchToSpace : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
BatchToSpace() = default;
/// \brief Constructs a BatchToSpace operation.
///
/// \param data Node producing the data tensor
/// \param block_shape The sizes of the block of values to be moved
/// \param crops_begin Specifies the amount to crop from the beginning along each
/// axis of `data` input
/// \param crops_end Specifies the amount to crop from the ending along each axis of
/// `data` input.
BatchToSpace(const Output<Node>& data,
const Output<Node>& block_shape,
const Output<Node>& crops_begin,
const Output<Node>& crops_end);
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief BatchToSpace permutes data from the batch dimension of the data tensor into
/// spatial dimensions.
///
/// \note Values from the batch dimension are moved in spatial blocks dimensions.
///
/// Output node produces a tensor with shape:
/// `[batch / (block_shape[0] * block_shape[1] * ... * block_shape[N - 1]),
/// D_1 * block_shape[1] - crops_begin[1] - crops_end[1],
/// D_2 * block_shape[2] - crops_begin[2] - crops_end[2], ...,
/// D_{N - 1} * block_shape[N - 1] - crops_begin[N - 1] - crops_end[N - 1]`
/// of the same type as `data` input.
class NGRAPH_API BatchToSpace : public Op {
public:
NGRAPH_RTTI_DECLARATION;
BatchToSpace() = default;
/// \brief Constructs a BatchToSpace operation.
///
/// \param data Node producing the data tensor
/// \param block_shape The sizes of the block of values to be moved
/// \param crops_begin Specifies the amount to crop from the beginning along each
/// axis of `data` input
/// \param crops_end Specifies the amount to crop from the ending along each axis of
/// `data` input.
BatchToSpace(const Output<Node>& data,
const Output<Node>& block_shape,
const Output<Node>& crops_begin,
const Output<Node>& crops_end);
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
};
} // namespace v1
} // namespace op
} // namespace ngraph
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -8,116 +8,136 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
class NGRAPH_API BinaryConvolution : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
class NGRAPH_API BinaryConvolution : public Op {
public:
NGRAPH_RTTI_DECLARATION;
enum class BinaryConvolutionMode
{
// Interpret input data and kernel values: 0 as -1, 1 as 1
XNOR_POPCOUNT
};
/// \brief Constructs a binary convolution operation.
BinaryConvolution() = default;
/// \brief Constructs a binary convolution operation.
/// \param data The node producing the input data batch tensor.
/// \param kernel The node producing the filters tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param dilations The dilations.
/// \param mode Defines how input tensor 0/1 values and weights 0/1 are interpreted.
/// \param pad_value Floating-point value used to fill pad area.
/// \param auto_pad The pad type for automatically computing padding sizes.
///
/// Output `[N, C_OUT, R1, ... Rf]`
BinaryConvolution(const Output<Node>& data,
const Output<Node>& kernel,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
BinaryConvolutionMode mode,
float pad_value,
const PadType& auto_pad = PadType::EXPLICIT);
BinaryConvolution(const Output<Node>& data,
const Output<Node>& kernel,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const std::string& mode,
float pad_value,
const PadType& auto_pad = PadType::EXPLICIT);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The strides.
const Strides& get_strides() const { return m_strides; }
void set_strides(const Strides& strides) { m_strides = strides; }
/// \return The dilations.
const Strides& get_dilations() const { return m_dilations; }
void set_dilations(const Strides& dilations) { m_dilations = dilations; }
/// \return The padding-below sizes (possibly negative).
const CoordinateDiff& get_pads_begin() const { return m_pads_begin; }
void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; }
/// \return The padding-above sizes (possibly negative).
const CoordinateDiff& get_pads_end() const { return m_pads_end; }
void set_adding_above(const CoordinateDiff& pads_end) { m_pads_end = pads_end; }
/// \return The pad type for convolution.
const PadType& get_auto_pad() const { return m_auto_pad; }
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; }
/// \return The mode of convolution.
const BinaryConvolutionMode& get_mode() const { return m_mode; }
void set_mode(const BinaryConvolutionMode& mode) { m_mode = mode; }
/// \return The pad value.
float get_pad_value() const { return m_pad_value; }
void set_pad_value(float pad_value) { m_pad_value = pad_value; }
protected:
BinaryConvolutionMode mode_from_string(const std::string& mode) const;
Strides m_strides;
Strides m_dilations;
CoordinateDiff m_pads_begin;
CoordinateDiff m_pads_end;
BinaryConvolutionMode m_mode;
float m_pad_value;
PadType m_auto_pad;
};
} // namespace v1
} // namespace op
NGRAPH_API
std::ostream& operator<<(std::ostream& s,
const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>
{
public:
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
enum class BinaryConvolutionMode {
// Interpret input data and kernel values: 0 as -1, 1 as 1
XNOR_POPCOUNT
};
} // namespace ngraph
/// \brief Constructs a binary convolution operation.
BinaryConvolution() = default;
/// \brief Constructs a binary convolution operation.
/// \param data The node producing the input data batch tensor.
/// \param kernel The node producing the filters tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param dilations The dilations.
/// \param mode Defines how input tensor 0/1 values and weights 0/1 are interpreted.
/// \param pad_value Floating-point value used to fill pad area.
/// \param auto_pad The pad type for automatically computing padding sizes.
///
/// Output `[N, C_OUT, R1, ... Rf]`
BinaryConvolution(const Output<Node>& data,
const Output<Node>& kernel,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
BinaryConvolutionMode mode,
float pad_value,
const PadType& auto_pad = PadType::EXPLICIT);
BinaryConvolution(const Output<Node>& data,
const Output<Node>& kernel,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const std::string& mode,
float pad_value,
const PadType& auto_pad = PadType::EXPLICIT);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The strides.
const Strides& get_strides() const {
return m_strides;
}
void set_strides(const Strides& strides) {
m_strides = strides;
}
/// \return The dilations.
const Strides& get_dilations() const {
return m_dilations;
}
void set_dilations(const Strides& dilations) {
m_dilations = dilations;
}
/// \return The padding-below sizes (possibly negative).
const CoordinateDiff& get_pads_begin() const {
return m_pads_begin;
}
void set_pads_begin(const CoordinateDiff& pads_begin) {
m_pads_begin = pads_begin;
}
/// \return The padding-above sizes (possibly negative).
const CoordinateDiff& get_pads_end() const {
return m_pads_end;
}
void set_adding_above(const CoordinateDiff& pads_end) {
m_pads_end = pads_end;
}
/// \return The pad type for convolution.
const PadType& get_auto_pad() const {
return m_auto_pad;
}
void set_auto_pad(const PadType& auto_pad) {
m_auto_pad = auto_pad;
}
/// \return The mode of convolution.
const BinaryConvolutionMode& get_mode() const {
return m_mode;
}
void set_mode(const BinaryConvolutionMode& mode) {
m_mode = mode;
}
/// \return The pad value.
float get_pad_value() const {
return m_pad_value;
}
void set_pad_value(float pad_value) {
m_pad_value = pad_value;
}
protected:
BinaryConvolutionMode mode_from_string(const std::string& mode) const;
Strides m_strides;
Strides m_dilations;
CoordinateDiff m_pads_begin;
CoordinateDiff m_pads_end;
BinaryConvolutionMode m_mode;
float m_pad_value;
PadType m_auto_pad;
};
} // namespace v1
} // namespace op
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode> {
public:
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>",
0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
} // namespace ngraph

View File

@ -9,135 +9,125 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/broadcast_base.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class NGRAPH_API Broadcast : public util::BroadcastBase
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class NGRAPH_API Broadcast : public util::BroadcastBase {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' should not be provided if mode other
/// than explicit (none) is used.
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' should not be provided if mode other
/// than explicit (none) is used.
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
// \return Broadcast Specification.
const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
{
m_mode = broadcast_spec;
}
// \return Broadcast Specification.
const BroadcastModeSpec& get_broadcast_spec() const {
return m_mode;
}
void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec) {
m_mode = broadcast_spec;
}
void validate_and_infer_types() override;
void validate_and_infer_types() override;
/// \return true and the AxisSet if broadcast axes can be fully determined.
std::pair<bool, AxisSet> get_broadcast_axes() const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
/// \return true and the AxisSet if broadcast axes can be fully determined.
std::pair<bool, AxisSet> get_broadcast_axes() const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
private:
bool broadcast_evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const;
};
} // namespace v3
private:
bool broadcast_evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
};
} // namespace v3
namespace v1
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class NGRAPH_API Broadcast : public util::BroadcastBase
{
public:
NGRAPH_RTTI_DECLARATION;
namespace v1 {
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class NGRAPH_API Broadcast : public util::BroadcastBase {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' is ignored if broadcast_spec is not
/// NONE
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' is ignored if broadcast_spec is not
/// NONE
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
{
m_broadcast_spec = broadcast_spec;
}
/// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const {
return m_broadcast_spec;
}
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec) {
m_broadcast_spec = broadcast_spec;
}
void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
protected:
AutoBroadcastSpec m_broadcast_spec;
};
} // namespace v1
} // namespace op
} // namespace ngraph
protected:
AutoBroadcastSpec m_broadcast_spec;
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -6,53 +6,53 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Operation that bucketizes the input based on boundaries
class NGRAPH_API Bucketize : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Operation that bucketizes the input based on boundaries
class NGRAPH_API Bucketize : public Op {
public:
NGRAPH_RTTI_DECLARATION;
Bucketize() = default;
/// \brief Constructs a Bucketize node
Bucketize() = default;
/// \brief Constructs a Bucketize node
/// \param data Input data to bucketize
/// \param buckets 1-D of sorted unique boundaries for buckets
/// \param output_type Output tensor type, "i64" or "i32", defaults to i64
/// \param with_right_bound indicates whether bucket includes the right or left
/// edge of interval. default true = includes right edge
Bucketize(const Output<Node>& data,
const Output<Node>& buckets,
const element::Type output_type = element::i64,
const bool with_right_bound = true);
/// \param data Input data to bucketize
/// \param buckets 1-D of sorted unique boundaries for buckets
/// \param output_type Output tensor type, "i64" or "i32", defaults to i64
/// \param with_right_bound indicates whether bucket includes the right or left
/// edge of interval. default true = includes right edge
Bucketize(const Output<Node>& data,
const Output<Node>& buckets,
const element::Type output_type = element::i64,
const bool with_right_bound = true);
virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
element::Type get_output_type() const { return m_output_type; }
void set_output_type(element::Type output_type) { m_output_type = output_type; }
// Overload collision with method on Node
using Node::set_output_type;
element::Type get_output_type() const {
return m_output_type;
}
void set_output_type(element::Type output_type) {
m_output_type = output_type;
}
// Overload collision with method on Node
using Node::set_output_type;
bool get_with_right_bound() const { return m_with_right_bound; }
void set_with_right_bound(bool with_right_bound)
{
m_with_right_bound = with_right_bound;
}
bool get_with_right_bound() const {
return m_with_right_bound;
}
void set_with_right_bound(bool with_right_bound) {
m_with_right_bound = with_right_bound;
}
private:
element::Type m_output_type;
bool m_with_right_bound;
};
} // namespace v3
using v3::Bucketize;
} // namespace op
} // namespace ngraph
private:
element::Type m_output_type;
bool m_with_right_bound;
};
} // namespace v3
using v3::Bucketize;
} // namespace op
} // namespace ngraph

View File

@ -6,32 +6,28 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise ceiling operation.
class NGRAPH_API Ceiling : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a ceiling operation.
Ceiling() = default;
/// \brief Constructs a ceiling operation.
///
/// \param arg Node that produces the input tensor.
Ceiling(const Output<Node>& arg);
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise ceiling operation.
class NGRAPH_API Ceiling : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a ceiling operation.
Ceiling() = default;
/// \brief Constructs a ceiling operation.
///
/// \param arg Node that produces the input tensor.
Ceiling(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor&) override { return true; }
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Ceiling;
} // namespace op
} // namespace ngraph
bool visit_attributes(AttributeVisitor&) override {
return true;
}
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Ceiling;
} // namespace op
} // namespace ngraph

View File

@ -7,48 +7,46 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Performs a clipping operation on all elements of the input node
///
/// All input values that are outside of the <min;max> range are set to 'min' or 'max'
/// depending on which side of the <min;max> range they are. The values that fall into
/// this range remain unchanged.
class NGRAPH_API Clamp : public ngraph::op::Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Performs a clipping operation on all elements of the input node
///
/// All input values that are outside of the <min;max> range are set to 'min' or 'max'
/// depending on which side of the <min;max> range they are. The values that fall into
/// this range remain unchanged.
class NGRAPH_API Clamp : public ngraph::op::Op {
public:
NGRAPH_RTTI_DECLARATION;
Clamp();
/// \brief Constructs a Clamp node.
///
/// \param data - Node producing the input tensor
/// \param min - the lower bound of the <min;max> range
/// \param max - the upper bound of the <min;max> range
Clamp(const Output<Node>& data, const double min, const double max);
Clamp();
/// \brief Constructs a Clamp node.
///
/// \param data - Node producing the input tensor
/// \param min - the lower bound of the <min;max> range
/// \param max - the upper bound of the <min;max> range
Clamp(const Output<Node>& data, const double min, const double max);
void validate_and_infer_types() override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
double get_min() const { return m_min; }
double get_max() const { return m_max; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
double get_min() const {
return m_min;
}
double get_max() const {
return m_max;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
private:
double m_min;
double m_max;
};
} // namespace v0
using v0::Clamp;
} // namespace op
} // namespace ngraph
private:
double m_min;
double m_max;
};
} // namespace v0
using v0::Clamp;
} // namespace op
} // namespace ngraph

View File

@ -8,60 +8,59 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Concatenation operation.
class NGRAPH_API Concat : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Concatenation operation.
class NGRAPH_API Concat : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a concatenation operation.
Concat() = default;
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, int64_t axis);
/// \brief Constructs a concatenation operation.
Concat() = default;
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, int64_t axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, int64_t axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, int64_t axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The concatenation axis.
int64_t get_concatenation_axis() const { return m_concat_axis; }
void set_concatenation_axis(int64_t concatenation_axis)
{
m_concat_axis = concatenation_axis;
}
/// \return The concatenation axis.
int64_t get_axis() const { return m_axis; }
void set_axis(int64_t axis) { m_axis = axis; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& output_values) const override;
bool evaluate_upper(const HostTensorVector& output_values) const override;
/// \return The concatenation axis.
int64_t get_concatenation_axis() const {
return m_concat_axis;
}
void set_concatenation_axis(int64_t concatenation_axis) {
m_concat_axis = concatenation_axis;
}
/// \return The concatenation axis.
int64_t get_axis() const {
return m_axis;
}
void set_axis(int64_t axis) {
m_axis = axis;
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& output_values) const override;
bool evaluate_upper(const HostTensorVector& output_values) const override;
protected:
/// \ brief m_axis stores default value for all iterations
int64_t m_axis;
/// \brief m_concat_axis stores m_axis plus the number of rank for each iteration
int64_t m_concat_axis = -1;
};
} // namespace v0
using v0::Concat;
} // namespace op
} // namespace ngraph
protected:
/// \ brief m_axis stores default value for all iterations
int64_t m_axis;
/// \brief m_concat_axis stores m_axis plus the number of rank for each iteration
int64_t m_concat_axis = -1;
};
} // namespace v0
using v0::Concat;
} // namespace op
} // namespace ngraph

File diff suppressed because it is too large Load Diff

View File

@ -7,51 +7,47 @@
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/host_tensor.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise type conversion operation.
class NGRAPH_API Convert : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise type conversion operation.
class NGRAPH_API Convert : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a conversion operation.
Convert() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param destination_type Element type for the output tensor.
Convert(const Output<Node>& arg, const ngraph::element::Type& destination_type);
/// \brief Constructs a conversion operation.
Convert() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param destination_type Element type for the output tensor.
Convert(const Output<Node>& arg, const ngraph::element::Type& destination_type);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
const element::Type& get_destination_type() const { return m_destination_type; }
void set_destination_type(const element::Type& destination_type)
{
m_destination_type = destination_type;
}
const element::Type& get_convert_element_type() const { return m_destination_type; }
void set_convert_element_type(const element::Type& destination_type)
{
m_destination_type = destination_type;
}
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const element::Type& get_destination_type() const {
return m_destination_type;
}
void set_destination_type(const element::Type& destination_type) {
m_destination_type = destination_type;
}
const element::Type& get_convert_element_type() const {
return m_destination_type;
}
void set_convert_element_type(const element::Type& destination_type) {
m_destination_type = destination_type;
}
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& outputs) const override;
bool evaluate_upper(const HostTensorVector& outputs) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
bool evaluate_lower(const HostTensorVector& outputs) const override;
bool evaluate_upper(const HostTensorVector& outputs) const override;
protected:
ngraph::element::Type m_destination_type;
};
} // namespace v0
using v0::Convert;
} // namespace op
} // namespace ngraph
protected:
ngraph::element::Type m_destination_type;
};
} // namespace v0
using v0::Convert;
} // namespace op
} // namespace ngraph

View File

@ -6,37 +6,31 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Elementwise type conversion operation.
class NGRAPH_API ConvertLike : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Elementwise type conversion operation.
class NGRAPH_API ConvertLike : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a conversion operation.
ConvertLike() = default;
/// \brief Constructs a conversion operation.
/// \param data Node that produces the input tensor.
/// \param like Node which provides the target type information for the conversion.
ConvertLike(const Output<Node>& data, const Output<Node>& like);
/// \brief Constructs a conversion operation.
ConvertLike() = default;
/// \brief Constructs a conversion operation.
/// \param data Node that produces the input tensor.
/// \param like Node which provides the target type information for the conversion.
ConvertLike(const Output<Node>& data, const Output<Node>& like);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool constant_fold(OutputVector& output_values,
const OutputVector& input_values) override;
};
bool constant_fold(OutputVector& output_values, const OutputVector& input_values) override;
};
} // namespace v1
} // namespace v1
} // namespace op
} // namespace op
} // namespace ngraph
} // namespace ngraph

View File

@ -8,91 +8,105 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Batched convolution operation, with optional window dilation and stride.
///
class NGRAPH_API Convolution : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Batched convolution operation, with optional window dilation and stride.
///
class NGRAPH_API Convolution : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a batched convolution operation.
Convolution() = default;
/// \brief Constructs a batched convolution operation.
///
/// \param data_batch The node producing the input data batch tensor.<br>
/// `[N, C_IN, D1, ... Df]`
/// \param filters The node producing the filters tensor.<br>
/// `[C_OUT, C_IN, F1, ... Ff]`
/// \param strides The strides.<br>
/// `[f]`
/// \param dilations The dilations.<br>
/// `[f]`
/// \param pads_begin The beginning of padding shape.<br>
/// `[f]`
/// \param pads_end The end of padding shape.<br>
/// `[f]`
/// \param auto_pad The pad type for automatically computing padding sizes.<br>
/// `[f]`
///
/// Output `[N, C_OUT, R1, ... Rf]`
///
Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT);
/// \brief Constructs a batched convolution operation.
Convolution() = default;
/// \brief Constructs a batched convolution operation.
///
/// \param data_batch The node producing the input data batch tensor.<br>
/// `[N, C_IN, D1, ... Df]`
/// \param filters The node producing the filters tensor.<br>
/// `[C_OUT, C_IN, F1, ... Ff]`
/// \param strides The strides.<br>
/// `[f]`
/// \param dilations The dilations.<br>
/// `[f]`
/// \param pads_begin The beginning of padding shape.<br>
/// `[f]`
/// \param pads_end The end of padding shape.<br>
/// `[f]`
/// \param auto_pad The pad type for automatically computing padding sizes.<br>
/// `[f]`
///
/// Output `[N, C_OUT, R1, ... Rf]`
///
Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The strides.
const Strides& get_strides() const { return m_strides; }
void set_strides(const Strides& strides) { m_strides = strides; }
/// \return The dilations.
const Strides& get_dilations() const { return m_dilations; }
void set_dilations(const Strides& dilations) { m_dilations = dilations; }
/// \return The padding-below sizes (possibly negative).
const CoordinateDiff& get_pads_begin() const { return m_pads_begin; }
void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; }
/// \return The padding-above sizes (possibly negative).
const CoordinateDiff& get_pads_end() const { return m_pads_end; }
void set_adding_above(const CoordinateDiff& pads_end) { m_pads_end = pads_end; }
/// \return The pad type for convolution.
const PadType& get_auto_pad() const { return m_auto_pad; }
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; }
/// \return The default value for Convolution.
NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END
/// \return The strides.
const Strides& get_strides() const {
return m_strides;
}
void set_strides(const Strides& strides) {
m_strides = strides;
}
/// \return The dilations.
const Strides& get_dilations() const {
return m_dilations;
}
void set_dilations(const Strides& dilations) {
m_dilations = dilations;
}
/// \return The padding-below sizes (possibly negative).
const CoordinateDiff& get_pads_begin() const {
return m_pads_begin;
}
void set_pads_begin(const CoordinateDiff& pads_begin) {
m_pads_begin = pads_begin;
}
/// \return The padding-above sizes (possibly negative).
const CoordinateDiff& get_pads_end() const {
return m_pads_end;
}
void set_adding_above(const CoordinateDiff& pads_end) {
m_pads_end = pads_end;
}
/// \return The pad type for convolution.
const PadType& get_auto_pad() const {
return m_auto_pad;
}
void set_auto_pad(const PadType& auto_pad) {
m_auto_pad = auto_pad;
}
/// \return The default value for Convolution.
NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END
protected:
Strides m_strides;
Strides m_dilations;
CoordinateDiff m_pads_begin;
CoordinateDiff m_pads_end;
PadType m_auto_pad;
};
protected:
Strides m_strides;
Strides m_dilations;
CoordinateDiff m_pads_begin;
CoordinateDiff m_pads_end;
PadType m_auto_pad;
};
/// \brief Data batch backprop for batched convolution operation.
class NGRAPH_API ConvolutionBackpropData : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Data batch backprop for batched convolution operation.
class NGRAPH_API ConvolutionBackpropData : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a batched-convolution data batch-backprop operation.
ConvolutionBackpropData() = default;
// clang-format off
/// \brief Constructs a batched-convolution data batch-backprop operation.
ConvolutionBackpropData() = default;
// clang-format off
//
// \brief Constructs a batched-convolution data batch-backprop operation.
//

View File

@ -6,33 +6,27 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise cosine operation.
class NGRAPH_API Cos : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise cosine operation.
class NGRAPH_API Cos : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a cosine operation.
Cos() = default;
/// \brief Constructs a cosine operation.
///
/// \param arg Node that produces the input tensor.
Cos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Constructs a cosine operation.
Cos() = default;
/// \brief Constructs a cosine operation.
///
/// \param arg Node that produces the input tensor.
Cos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Cos;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Cos;
} // namespace op
} // namespace ngraph

View File

@ -6,33 +6,27 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise hyperbolic cosine (cosh) operation.
class NGRAPH_API Cosh : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise hyperbolic cosine (cosh) operation.
class NGRAPH_API Cosh : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a hyperbolic cosine operation.
Cosh() = default;
/// \brief Constructs a hyperbolic cosine operation.
///
/// \param arg Node that produces the input tensor.
Cosh(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Constructs a hyperbolic cosine operation.
Cosh() = default;
/// \brief Constructs a hyperbolic cosine operation.
///
/// \param arg Node that produces the input tensor.
Cosh(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Cosh;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Cosh;
} // namespace op
} // namespace ngraph

View File

@ -6,38 +6,33 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
class NGRAPH_API CTCGreedyDecoder : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
class NGRAPH_API CTCGreedyDecoder : public Op {
public:
NGRAPH_RTTI_DECLARATION;
CTCGreedyDecoder() = default;
/// \brief Constructs a CTCGreedyDecoder operation
///
/// \param input Logits on which greedy decoding is performed
/// \param seq_len Sequence lengths
/// \param ctc_merge_repeated Whether to merge repeated labels
CTCGreedyDecoder(const Output<Node>& input,
const Output<Node>& seq_len,
const bool ctc_merge_repeated);
CTCGreedyDecoder() = default;
/// \brief Constructs a CTCGreedyDecoder operation
///
/// \param input Logits on which greedy decoding is performed
/// \param seq_len Sequence lengths
/// \param ctc_merge_repeated Whether to merge repeated labels
CTCGreedyDecoder(const Output<Node>& input, const Output<Node>& seq_len, const bool ctc_merge_repeated);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_ctc_merge_repeated() const { return m_ctc_merge_repeated; }
bool get_ctc_merge_repeated() const {
return m_ctc_merge_repeated;
}
private:
bool m_ctc_merge_repeated;
};
} // namespace v0
using v0::CTCGreedyDecoder;
} // namespace op
} // namespace ngraph
private:
bool m_ctc_merge_repeated;
};
} // namespace v0
using v0::CTCGreedyDecoder;
} // namespace op
} // namespace ngraph

View File

@ -6,99 +6,95 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief Operator performing CTCGreedyDecoder
///
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
CTCGreedyDecoderSeqLen() = default;
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
///
/// \param input 3-D tensor of logits on which greedy decoding is
/// performed
/// \param seq_len 1-D tensor of sequence lengths
/// \param merge_repeated Whether to merge repeated labels
/// \param classes_index_type Specifies the output classes_index tensor type
/// \param sequence_length_type Specifies the output sequence_length tensor type
CTCGreedyDecoderSeqLen(const Output<Node>& input,
const Output<Node>& seq_len,
const bool merge_repeated = true,
const element::Type& classes_index_type = element::i32,
const element::Type& sequence_length_type = element::i32);
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
///
/// \param input 3-D tensor of logits on which greedy decoding is
/// performed
/// \param seq_len 1-D tensor of sequence lengths
/// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
/// blank index
/// \param merge_repeated Whether to merge repeated labels
/// \param classes_index_type Specifies the output classes_index tensor type
/// \param sequence_length_type Specifies the output sequence_length tensor type
CTCGreedyDecoderSeqLen(const Output<Node>& input,
const Output<Node>& seq_len,
const Output<Node>& blank_index,
const bool merge_repeated = true,
const element::Type& classes_index_type = element::i32,
const element::Type& sequence_length_type = element::i32);
namespace ngraph {
namespace op {
namespace v6 {
/// \brief Operator performing CTCGreedyDecoder
///
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op {
public:
NGRAPH_RTTI_DECLARATION;
CTCGreedyDecoderSeqLen() = default;
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
///
/// \param input 3-D tensor of logits on which greedy decoding is
/// performed
/// \param seq_len 1-D tensor of sequence lengths
/// \param merge_repeated Whether to merge repeated labels
/// \param classes_index_type Specifies the output classes_index tensor type
/// \param sequence_length_type Specifies the output sequence_length tensor type
CTCGreedyDecoderSeqLen(const Output<Node>& input,
const Output<Node>& seq_len,
const bool merge_repeated = true,
const element::Type& classes_index_type = element::i32,
const element::Type& sequence_length_type = element::i32);
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
///
/// \param input 3-D tensor of logits on which greedy decoding is
/// performed
/// \param seq_len 1-D tensor of sequence lengths
/// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
/// blank index
/// \param merge_repeated Whether to merge repeated labels
/// \param classes_index_type Specifies the output classes_index tensor type
/// \param sequence_length_type Specifies the output sequence_length tensor type
CTCGreedyDecoderSeqLen(const Output<Node>& input,
const Output<Node>& seq_len,
const Output<Node>& blank_index,
const bool merge_repeated = true,
const element::Type& classes_index_type = element::i32,
const element::Type& sequence_length_type = element::i32);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Get merge_repeated attribute
///
/// \return Current value of merge_repeated attribute
///
bool get_merge_repeated() const { return m_merge_repeated; }
/// \brief Get classes_index_type attribute
///
/// \return Current value of classes_index_type attribute
///
const element::Type& get_classes_index_type() const { return m_classes_index_type; }
/// \brief Set classes_index_type attribute
///
/// \param classes_index_type Type of classes_index
///
void set_classes_index_type(const element::Type& classes_index_type)
{
m_classes_index_type = classes_index_type;
validate_and_infer_types();
}
/// \brief Get merge_repeated attribute
///
/// \return Current value of merge_repeated attribute
///
bool get_merge_repeated() const {
return m_merge_repeated;
}
/// \brief Get classes_index_type attribute
///
/// \return Current value of classes_index_type attribute
///
const element::Type& get_classes_index_type() const {
return m_classes_index_type;
}
/// \brief Set classes_index_type attribute
///
/// \param classes_index_type Type of classes_index
///
void set_classes_index_type(const element::Type& classes_index_type) {
m_classes_index_type = classes_index_type;
validate_and_infer_types();
}
/// \brief Get sequence_length_type attribute
///
/// \return Current value of sequence_length_type attribute
///
const element::Type& get_sequence_length_type() const
{
return m_sequence_length_type;
}
/// \brief Get sequence_length_type attribute
///
/// \return Current value of sequence_length_type attribute
///
const element::Type& get_sequence_length_type() const {
return m_sequence_length_type;
}
/// \brief Set sequence_length_type attribute
///
/// \param sequence_length_type Type of sequence length
///
void set_sequence_length_type(const element::Type& sequence_length_type)
{
m_sequence_length_type = sequence_length_type;
validate_and_infer_types();
}
/// \brief Set sequence_length_type attribute
///
/// \param sequence_length_type Type of sequence length
///
void set_sequence_length_type(const element::Type& sequence_length_type) {
m_sequence_length_type = sequence_length_type;
validate_and_infer_types();
}
private:
bool m_merge_repeated;
element::Type m_classes_index_type{element::i32};
element::Type m_sequence_length_type{element::i32};
};
} // namespace v6
} // namespace op
} // namespace ngraph
private:
bool m_merge_repeated;
element::Type m_classes_index_type{element::i32};
element::Type m_sequence_length_type{element::i32};
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -6,68 +6,68 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v4
{
class NGRAPH_API CTCLoss : public Op
{
public:
static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
CTCLoss() = default;
/// \brief Constructs a CTCLoss operation
///
/// \param logits 3-D tensor of logits
/// \param logit_length 1-D tensor of length for each object from
/// a batch
/// \param labels 2-D tensor of labels for which likelyhood
/// is estimated using logist
/// \param label_length 1-D tensor of length for each label
/// sequence
/// \param blank_index Scalar used to mark a blank index
/// \param preprocess_collapse_repeated Flag for preprocessing labels before loss
/// calculation
/// \param ctc_merge_repeated Flag for merging repeated characters in a
/// potential alignment
/// \param unique Flag to find unique elements in a target
/// before matching with alignment
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const bool preprocess_collapse_repeated = false,
const bool ctc_merge_repeated = true,
const bool unique = false);
namespace ngraph {
namespace op {
namespace v4 {
class NGRAPH_API CTCLoss : public Op {
public:
static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
CTCLoss() = default;
/// \brief Constructs a CTCLoss operation
///
/// \param logits 3-D tensor of logits
/// \param logit_length 1-D tensor of length for each object from
/// a batch
/// \param labels 2-D tensor of labels for which likelyhood
/// is estimated using logist
/// \param label_length 1-D tensor of length for each label
/// sequence
/// \param blank_index Scalar used to mark a blank index
/// \param preprocess_collapse_repeated Flag for preprocessing labels before loss
/// calculation
/// \param ctc_merge_repeated Flag for merging repeated characters in a
/// potential alignment
/// \param unique Flag to find unique elements in a target
/// before matching with alignment
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const bool preprocess_collapse_repeated = false,
const bool ctc_merge_repeated = true,
const bool unique = false);
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const Output<Node>& blank_index,
const bool preprocess_collapse_repeated = false,
const bool ctc_merge_repeated = true,
const bool unique = false);
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const Output<Node>& blank_index,
const bool preprocess_collapse_repeated = false,
const bool ctc_merge_repeated = true,
const bool unique = false);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_preprocess_collapse_repeated() const
{
return preprocess_collapse_repeated_;
}
bool get_ctc_merge_repeated() const { return ctc_merge_repeated_; }
bool get_unique() const { return unique_; }
bool get_preprocess_collapse_repeated() const {
return preprocess_collapse_repeated_;
}
bool get_ctc_merge_repeated() const {
return ctc_merge_repeated_;
}
bool get_unique() const {
return unique_;
}
private:
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
bool unique_;
};
} // namespace v4
} // namespace op
} // namespace ngraph
private:
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
bool unique_;
};
} // namespace v4
} // namespace op
} // namespace ngraph

View File

@ -7,101 +7,95 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Tensor cumulative sum operation.
///
/// Compute the cumulative sum of the input tensor along the axis specified.
///
/// ## Parameters
///
/// | | Description |
/// | -------------------- |
/// --------------------------------------------------------------------------------------------------|
/// | `exclusive` | If set to 1 will return exclusive sum in which the top
/// element
/// is not included. |
/// | | In other terms, if set to 1, the j-th output element
/// would be
/// the
/// sum of the first (j-1) elements.|
/// | | Otherwise, it would be the sum of the first j elements.
/// |
///
/// | | Description |
/// | -------------------- | -------------------------------------------------- |
/// | `reverse` | if set to 1, performs the sum in reverse direction |
///
/// ## Inputs
///
/// | | Description |
/// | ----- | ------------------------------------------------------ |
/// | `arg` | An input tensor of any shape and numeric element type. |
///
/// | | Description |
/// | ----- |
/// ------------------------------------------------------------------------------------------------|
/// | `axis`| zero dimension tensor specifying axis position along which cumulative sum
/// must
/// be performed. |
///
/// ## Output
///
/// | Description |
/// |
/// ------------------------------------------------------------------------------------|
/// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements
/// |
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Tensor cumulative sum operation.
///
/// Compute the cumulative sum of the input tensor along the axis specified.
///
/// ## Parameters
///
/// | | Description |
/// | -------------------- |
/// --------------------------------------------------------------------------------------------------|
/// | `exclusive` | If set to 1 will return exclusive sum in which the top
/// element
/// is not included. |
/// | | In other terms, if set to 1, the j-th output element
/// would be
/// the
/// sum of the first (j-1) elements.|
/// | | Otherwise, it would be the sum of the first j elements.
/// |
///
/// | | Description |
/// | -------------------- | -------------------------------------------------- |
/// | `reverse` | if set to 1, performs the sum in reverse direction |
///
/// ## Inputs
///
/// | | Description |
/// | ----- | ------------------------------------------------------ |
/// | `arg` | An input tensor of any shape and numeric element type. |
///
/// | | Description |
/// | ----- |
/// ------------------------------------------------------------------------------------------------|
/// | `axis`| zero dimension tensor specifying axis position along which cumulative sum
/// must
/// be performed. |
///
/// ## Output
///
/// | Description |
/// |
/// ------------------------------------------------------------------------------------|
/// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements
/// |
class NGRAPH_API CumSum : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
class NGRAPH_API CumSum : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a cumulative summation operation.
CumSum() = default;
/// \brief Constructs a cumulative summation operation.
CumSum() = default;
/// \brief Constructs a cumulative summation operation.
///
/// \param arg The tensor to be summed.
/// \param axis zero dimension tensor specifying axis position along which
/// cumulative sum must be performed
/// \param exclusive if set to true, the top element is not included
/// \param reverse if set to true, will perform the sums in reverse direction
CumSum(const Output<Node>& arg,
const Output<Node>& axis,
const bool exclusive = false,
const bool reverse = false);
/// \brief Constructs a cumulative summation operation.
///
/// \param arg The tensor to be summed.
/// \param axis zero dimension tensor specifying axis position along which
/// cumulative sum must be performed
/// \param exclusive if set to true, the top element is not included
/// \param reverse if set to true, will perform the sums in reverse direction
CumSum(const Output<Node>& arg, const Output<Node>& axis, const bool exclusive = false, const bool reverse = false);
/// \brief Constructs a cumulative summation operation with axis = 0
///
/// \param arg The tensor to be summed
CumSum(const Output<Node>& arg,
const bool exclusive = false,
const bool reverse = false);
/// \brief Constructs a cumulative summation operation with axis = 0
///
/// \param arg The tensor to be summed
CumSum(const Output<Node>& arg, const bool exclusive = false, const bool reverse = false);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
/// \return The default value for CumSum.
NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END
bool is_exclusive() const { return m_exclusive; }
bool is_reverse() const { return m_reverse; }
/// \return The default value for CumSum.
NGRAPH_SUPPRESS_DEPRECATED_START
virtual std::shared_ptr<Node> get_default_value() const override;
NGRAPH_SUPPRESS_DEPRECATED_END
bool is_exclusive() const {
return m_exclusive;
}
bool is_reverse() const {
return m_reverse;
}
private:
bool m_exclusive;
bool m_reverse;
};
} // namespace v0
using v0::CumSum;
} // namespace op
} // namespace ngraph
private:
bool m_exclusive;
bool m_reverse;
};
} // namespace v0
using v0::CumSum;
} // namespace op
} // namespace ngraph

View File

@ -9,175 +9,167 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/deformable_convolution_base.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief DeformableConvolution operation.
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief DeformableConvolution operation.
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a conversion operation.
DeformableConvolution() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param offsets Node producing the deformable values tensor.
/// \param filters Node producing the filters(kernels) tensor with OIZYX
/// layout.
/// \param strides Convolution strides.
/// \param pads_begin Amount of padding to be added to the beginning along
/// each axis. For example in case of a 2D input the value
/// of (1, 2) means that 1 element will be added to the
/// top and 2 elements to the left.
/// \param pads_end Amount of padding to be added to the end along each
/// axis.
/// \param dilations The distance in width and height between the weights
/// in the filters tensor.
/// \param auto_pad Specifies how the automatic calculation of padding
/// should be done.
/// \param group The number of groups which both output and input
/// should be split into.
/// \param deformable_group The number of groups which deformable values and
/// output should be split into along the channel axis.
DeformableConvolution(const Output<Node>& arg,
const Output<Node>& offsets,
const Output<Node>& filters,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT,
const int64_t group = 1,
const int64_t deformable_group = 1);
/// \brief Constructs a conversion operation.
DeformableConvolution() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param offsets Node producing the deformable values tensor.
/// \param filters Node producing the filters(kernels) tensor with OIZYX
/// layout.
/// \param strides Convolution strides.
/// \param pads_begin Amount of padding to be added to the beginning along
/// each axis. For example in case of a 2D input the value
/// of (1, 2) means that 1 element will be added to the
/// top and 2 elements to the left.
/// \param pads_end Amount of padding to be added to the end along each
/// axis.
/// \param dilations The distance in width and height between the weights
/// in the filters tensor.
/// \param auto_pad Specifies how the automatic calculation of padding
/// should be done.
/// \param group The number of groups which both output and input
/// should be split into.
/// \param deformable_group The number of groups which deformable values and
/// output should be split into along the channel axis.
DeformableConvolution(const Output<Node>& arg,
const Output<Node>& offsets,
const Output<Node>& filters,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT,
const int64_t group = 1,
const int64_t deformable_group = 1);
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v1
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v1
namespace v8
{
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase
{
public:
NGRAPH_RTTI_DECLARATION;
namespace v8 {
class NGRAPH_API DeformableConvolution : public op::util::DeformableConvolutionBase {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a conversion operation.
DeformableConvolution() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param offsets Node producing the deformable values tensor.
/// \param filters Node producing the filters(kernels) tensor with OIZYX
/// layout.
/// \param strides Convolution strides.
/// \param pads_begin Amount of padding to be added to the beginning along
/// each axis. For example in case of a 2D input the value
/// of (1, 2) means that 1 element will be added to the
/// top and 2 elements to the left.
/// \param pads_end Amount of padding to be added to the end along each
/// axis.
/// \param dilations The distance in width and height between the weights
/// in the filters tensor.
/// \param auto_pad Specifies how the automatic calculation of padding
/// should be done.
/// \param group The number of groups which both output and input
/// should be split into.
/// \param deformable_group The number of groups which deformable values and
/// output should be split into along the channel axis.
/// \param bilinear_interpolation_pad
/// The flag that determines the mode of bilinear
/// interpolation execution.
/// If the flag is `true` and the sampling location is
/// within one pixel outside of the feature map boundary,
/// then bilinear interpolation is performed on the zero
/// padded feature map. If the flag is `false` and the
/// sampling location is within one pixel outside of the
/// feature map boundary, then the sampling location
/// shifts to the inner boundary of the feature map.`
DeformableConvolution(const Output<Node>& arg,
const Output<Node>& offsets,
const Output<Node>& filters,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT,
const int64_t group = 1,
const int64_t deformable_group = 1,
const bool bilinear_interpolation_pad = false);
/// \brief Constructs a conversion operation.
DeformableConvolution() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param offsets Node producing the deformable values tensor.
/// \param filters Node producing the filters(kernels) tensor with OIZYX
/// layout.
/// \param strides Convolution strides.
/// \param pads_begin Amount of padding to be added to the beginning along
/// each axis. For example in case of a 2D input the value
/// of (1, 2) means that 1 element will be added to the
/// top and 2 elements to the left.
/// \param pads_end Amount of padding to be added to the end along each
/// axis.
/// \param dilations The distance in width and height between the weights
/// in the filters tensor.
/// \param auto_pad Specifies how the automatic calculation of padding
/// should be done.
/// \param group The number of groups which both output and input
/// should be split into.
/// \param deformable_group The number of groups which deformable values and
/// output should be split into along the channel axis.
/// \param bilinear_interpolation_pad
/// The flag that determines the mode of bilinear
/// interpolation execution.
/// If the flag is `true` and the sampling location is
/// within one pixel outside of the feature map boundary,
/// then bilinear interpolation is performed on the zero
/// padded feature map. If the flag is `false` and the
/// sampling location is within one pixel outside of the
/// feature map boundary, then the sampling location
/// shifts to the inner boundary of the feature map.`
DeformableConvolution(const Output<Node>& arg,
const Output<Node>& offsets,
const Output<Node>& filters,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT,
const int64_t group = 1,
const int64_t deformable_group = 1,
const bool bilinear_interpolation_pad = false);
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param offsets Node producing the deformable values tensor.
/// \param filters Node producing the filters(kernels) tensor with OIZYX
/// layout.
/// \param mask Node producing the mask(mask) tensor.
/// \param strides Convolution strides.
/// \param pads_begin Amount of padding to be added to the beginning along
/// each axis. For example in case of a 2D input the value
/// of (1, 2) means that 1 element will be added to the
/// top and 2 elements to the left.
/// \param pads_end Amount of padding to be added to the end along each
/// axis.
/// \param dilations The distance in width and height between the weights
/// in the filters tensor.
/// \param auto_pad Specifies how the automatic calculation of padding
/// should be done.
/// \param group The number of groups which both output and input
/// should be split into.
/// \param deformable_group The number of groups which deformable values and
/// output should be split into along the channel axis.
/// \param bilinear_interpolation_pad
/// The flag that determines the mode of bilinear
/// interpolation execution.
/// If the flag is `true` and the sampling location is
/// within one pixel outside of the feature map boundary,
/// then bilinear interpolation is performed on the zero
/// padded feature map. If the flag is `false` and the
/// sampling location is within one pixel outside of the
/// feature map boundary, then the sampling location
/// shifts to the inner boundary of the feature map.
DeformableConvolution(const Output<Node>& arg,
const Output<Node>& offsets,
const Output<Node>& filters,
const Output<Node>& mask,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT,
const int64_t group = 1,
const int64_t deformable_group = 1,
const bool bilinear_interpolation_pad = false);
bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param offsets Node producing the deformable values tensor.
/// \param filters Node producing the filters(kernels) tensor with OIZYX
/// layout.
/// \param mask Node producing the mask(mask) tensor.
/// \param strides Convolution strides.
/// \param pads_begin Amount of padding to be added to the beginning along
/// each axis. For example in case of a 2D input the value
/// of (1, 2) means that 1 element will be added to the
/// top and 2 elements to the left.
/// \param pads_end Amount of padding to be added to the end along each
/// axis.
/// \param dilations The distance in width and height between the weights
/// in the filters tensor.
/// \param auto_pad Specifies how the automatic calculation of padding
/// should be done.
/// \param group The number of groups which both output and input
/// should be split into.
/// \param deformable_group The number of groups which deformable values and
/// output should be split into along the channel axis.
/// \param bilinear_interpolation_pad
/// The flag that determines the mode of bilinear
/// interpolation execution.
/// If the flag is `true` and the sampling location is
/// within one pixel outside of the feature map boundary,
/// then bilinear interpolation is performed on the zero
/// padded feature map. If the flag is `false` and the
/// sampling location is within one pixel outside of the
/// feature map boundary, then the sampling location
/// shifts to the inner boundary of the feature map.
DeformableConvolution(const Output<Node>& arg,
const Output<Node>& offsets,
const Output<Node>& filters,
const Output<Node>& mask,
const Strides& strides,
const CoordinateDiff& pads_begin,
const CoordinateDiff& pads_end,
const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT,
const int64_t group = 1,
const int64_t deformable_group = 1,
const bool bilinear_interpolation_pad = false);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
bool has_evaluate() const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_bilinear_interpolation_pad() const { return m_bilinear_interpolation_pad; }
bool get_bilinear_interpolation_pad() const {
return m_bilinear_interpolation_pad;
}
void set_bilinear_interpolation_pad(const bool bilinear_interpolation_pad)
{
m_bilinear_interpolation_pad = bilinear_interpolation_pad;
}
void set_bilinear_interpolation_pad(const bool bilinear_interpolation_pad) {
m_bilinear_interpolation_pad = bilinear_interpolation_pad;
}
private:
bool m_bilinear_interpolation_pad;
};
} // namespace v8
} // namespace op
} // namespace ngraph
private:
bool m_bilinear_interpolation_pad;
};
} // namespace v8
} // namespace op
} // namespace ngraph

View File

@ -6,91 +6,102 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
class NGRAPH_API DeformablePSROIPooling : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
class NGRAPH_API DeformablePSROIPooling : public Op {
public:
NGRAPH_RTTI_DECLARATION;
DeformablePSROIPooling() = default;
/// \brief Constructs a DeformablePSROIPooling operation
///
/// \param input Input tensor with position sensitive score maps
/// \param coords Input tensor with list of five element tuples
/// describing ROI coordinates
/// \param offsets Input tensor with transformation values
/// \param output_dim Pooled output channel number
/// \param group_size Number of horizontal bins per row to divide ROI area,
/// it defines output width and height
/// \param spatial_scale Multiplicative spatial scale factor to translate ROI
/// coordinates from their input scale to the scale used when
/// pooling
/// \param mode Specifies mode for pooling.
/// \param spatial_bins_x Specifies numbers of bins to divide ROI single
/// bin over width
/// \param spatial_bins_y Specifies numbers of bins to divide ROI single
/// bin over height
/// \param no_trans The flag that specifies whenever third input exists
/// and contains transformation (offset) values
/// \param trans_std The value that all transformation (offset) values are
/// multiplied with
/// \param part_size The number of parts the output tensor spatial dimensions
/// are divided into. Basically it is the height
/// and width of the third input
DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const Output<Node>& offsets,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size = 1,
const std::string mode = "bilinear_deformable",
int64_t spatial_bins_x = 1,
int64_t spatial_bins_y = 1,
float trans_std = 1,
int64_t part_size = 1);
DeformablePSROIPooling() = default;
/// \brief Constructs a DeformablePSROIPooling operation
///
/// \param input Input tensor with position sensitive score maps
/// \param coords Input tensor with list of five element tuples
/// describing ROI coordinates
/// \param offsets Input tensor with transformation values
/// \param output_dim Pooled output channel number
/// \param group_size Number of horizontal bins per row to divide ROI area,
/// it defines output width and height
/// \param spatial_scale Multiplicative spatial scale factor to translate ROI
/// coordinates from their input scale to the scale used when
/// pooling
/// \param mode Specifies mode for pooling.
/// \param spatial_bins_x Specifies numbers of bins to divide ROI single
/// bin over width
/// \param spatial_bins_y Specifies numbers of bins to divide ROI single
/// bin over height
/// \param no_trans The flag that specifies whenever third input exists
/// and contains transformation (offset) values
/// \param trans_std The value that all transformation (offset) values are
/// multiplied with
/// \param part_size The number of parts the output tensor spatial dimensions
/// are divided into. Basically it is the height
/// and width of the third input
DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const Output<Node>& offsets,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size = 1,
const std::string mode = "bilinear_deformable",
int64_t spatial_bins_x = 1,
int64_t spatial_bins_y = 1,
float trans_std = 1,
int64_t part_size = 1);
DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size = 1,
const std::string mode = "bilinear_deformable",
int64_t spatial_bins_x = 1,
int64_t spatial_bins_y = 1,
float trans_std = 1,
int64_t part_size = 1);
DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size = 1,
const std::string mode = "bilinear_deformable",
int64_t spatial_bins_x = 1,
int64_t spatial_bins_y = 1,
float trans_std = 1,
int64_t part_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_output_dim() const { return m_output_dim; }
int64_t get_group_size() const { return m_group_size; }
float get_spatial_scale() const { return m_spatial_scale; }
const std::string& get_mode() const { return m_mode; }
int64_t get_spatial_bins_x() const { return m_spatial_bins_x; }
int64_t get_spatial_bins_y() const { return m_spatial_bins_y; }
float get_trans_std() const { return m_trans_std; }
int64_t get_part_size() const { return m_part_size; }
int64_t get_output_dim() const {
return m_output_dim;
}
int64_t get_group_size() const {
return m_group_size;
}
float get_spatial_scale() const {
return m_spatial_scale;
}
const std::string& get_mode() const {
return m_mode;
}
int64_t get_spatial_bins_x() const {
return m_spatial_bins_x;
}
int64_t get_spatial_bins_y() const {
return m_spatial_bins_y;
}
float get_trans_std() const {
return m_trans_std;
}
int64_t get_part_size() const {
return m_part_size;
}
private:
int64_t m_output_dim;
float m_spatial_scale;
int64_t m_group_size = 1;
std::string m_mode = "bilinear_deformable";
int64_t m_spatial_bins_x = 1;
int64_t m_spatial_bins_y = 1;
float m_trans_std = 1.f;
int64_t m_part_size = 1;
};
} // namespace v1
} // namespace op
} // namespace ngraph
private:
int64_t m_output_dim;
float m_spatial_scale;
int64_t m_group_size = 1;
std::string m_mode = "bilinear_deformable";
int64_t m_spatial_bins_x = 1;
int64_t m_spatial_bins_y = 1;
float m_trans_std = 1.f;
int64_t m_part_size = 1;
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -7,81 +7,72 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief DepthToSpace permutes data from the depth dimension of the input blob into
/// spatial dimensions.
///
/// \note Values from the depth dimension (assuming NCHW layout) are moved in
/// spatial blocks to the height and width dimensions.
///
/// Output node produces a tensor with shape:
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
class NGRAPH_API DepthToSpace : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief DepthToSpace permutes data from the depth dimension of the input blob into
/// spatial dimensions.
///
/// \note Values from the depth dimension (assuming NCHW layout) are moved in
/// spatial blocks to the height and width dimensions.
///
/// Output node produces a tensor with shape:
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
class NGRAPH_API DepthToSpace : public Op {
public:
NGRAPH_RTTI_DECLARATION;
enum class DepthToSpaceMode
{
// The input depth is divided to [block_size, ..., block_size, new_depth]
BLOCKS_FIRST,
// The input depth is divided to [new_depth, block_size, ..., block_size]
DEPTH_FIRST
};
DepthToSpace() = default;
/// \brief Constructs a DepthToSpace operation.
///
/// \param data Node producing the input tensor
/// \param mode Specifies how the input depth dimension is split to block
/// coordinates
/// \param block_size The size of the block of values to be moved
DepthToSpace(const Output<Node>& data,
const DepthToSpaceMode& mode,
std::size_t block_size = 1);
DepthToSpace(const Output<Node>& data,
const std::string& mode,
std::size_t block_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
std::size_t get_block_size() const { return m_blocksize; }
DepthToSpaceMode get_mode() const { return m_mode; }
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
protected:
std::size_t m_blocksize;
DepthToSpaceMode m_mode;
};
} // namespace v0
using v0::DepthToSpace;
} // namespace op
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const op::v0::DepthToSpace::DepthToSpaceMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>
: public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>
{
public:
AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value)
: EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
enum class DepthToSpaceMode {
// The input depth is divided to [block_size, ..., block_size, new_depth]
BLOCKS_FIRST,
// The input depth is divided to [new_depth, block_size, ..., block_size]
DEPTH_FIRST
};
} // namespace ngraph
DepthToSpace() = default;
/// \brief Constructs a DepthToSpace operation.
///
/// \param data Node producing the input tensor
/// \param mode Specifies how the input depth dimension is split to block
/// coordinates
/// \param block_size The size of the block of values to be moved
DepthToSpace(const Output<Node>& data, const DepthToSpaceMode& mode, std::size_t block_size = 1);
DepthToSpace(const Output<Node>& data, const std::string& mode, std::size_t block_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
std::size_t get_block_size() const {
return m_blocksize;
}
DepthToSpaceMode get_mode() const {
return m_mode;
}
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
protected:
std::size_t m_blocksize;
DepthToSpaceMode m_mode;
};
} // namespace v0
using v0::DepthToSpace;
} // namespace op
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const op::v0::DepthToSpace::DepthToSpaceMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>
: public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode> {
public:
AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value)
: EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value) {}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0};
const DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
};
} // namespace ngraph

View File

@ -6,78 +6,74 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
struct DetectionOutputAttrs
{
int num_classes;
int background_label_id = 0;
int top_k = -1;
bool variance_encoded_in_target = false;
std::vector<int> keep_top_k;
std::string code_type = std::string{"caffe.PriorBoxParameter.CORNER"};
bool share_location = true;
float nms_threshold;
float confidence_threshold = 0;
bool clip_after_nms = false;
bool clip_before_nms = false;
bool decrease_label_id = false;
bool normalized = false;
size_t input_height = 1;
size_t input_width = 1;
float objectness_score = 0;
};
namespace ngraph {
namespace op {
struct DetectionOutputAttrs {
int num_classes;
int background_label_id = 0;
int top_k = -1;
bool variance_encoded_in_target = false;
std::vector<int> keep_top_k;
std::string code_type = std::string{"caffe.PriorBoxParameter.CORNER"};
bool share_location = true;
float nms_threshold;
float confidence_threshold = 0;
bool clip_after_nms = false;
bool clip_before_nms = false;
bool decrease_label_id = false;
bool normalized = false;
size_t input_height = 1;
size_t input_width = 1;
float objectness_score = 0;
};
namespace v0
{
/// \brief Layer which performs non-max suppression to
/// generate detection output using location and confidence predictions
class NGRAPH_API DetectionOutput : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace v0 {
/// \brief Layer which performs non-max suppression to
/// generate detection output using location and confidence predictions
class NGRAPH_API DetectionOutput : public Op {
public:
NGRAPH_RTTI_DECLARATION;
DetectionOutput() = default;
/// \brief Constructs a DetectionOutput operation
///
/// \param box_logits Box logits
/// \param class_preds Class predictions
/// \param proposals Proposals
/// \param aux_class_preds Auxilary class predictions
/// \param aux_box_preds Auxilary box predictions
/// \param attrs Detection Output attributes
DetectionOutput(const Output<Node>& box_logits,
const Output<Node>& class_preds,
const Output<Node>& proposals,
const Output<Node>& aux_class_preds,
const Output<Node>& aux_box_preds,
const DetectionOutputAttrs& attrs);
DetectionOutput() = default;
/// \brief Constructs a DetectionOutput operation
///
/// \param box_logits Box logits
/// \param class_preds Class predictions
/// \param proposals Proposals
/// \param aux_class_preds Auxilary class predictions
/// \param aux_box_preds Auxilary box predictions
/// \param attrs Detection Output attributes
DetectionOutput(const Output<Node>& box_logits,
const Output<Node>& class_preds,
const Output<Node>& proposals,
const Output<Node>& aux_class_preds,
const Output<Node>& aux_box_preds,
const DetectionOutputAttrs& attrs);
/// \brief Constructs a DetectionOutput operation
///
/// \param box_logits Box logits
/// \param class_preds Class predictions
/// \param proposals Proposals
/// \param attrs Detection Output attributes
DetectionOutput(const Output<Node>& box_logits,
const Output<Node>& class_preds,
const Output<Node>& proposals,
const DetectionOutputAttrs& attrs);
/// \brief Constructs a DetectionOutput operation
///
/// \param box_logits Box logits
/// \param class_preds Class predictions
/// \param proposals Proposals
/// \param attrs Detection Output attributes
DetectionOutput(const Output<Node>& box_logits,
const Output<Node>& class_preds,
const Output<Node>& proposals,
const DetectionOutputAttrs& attrs);
void validate_and_infer_types() override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const DetectionOutputAttrs& get_attrs() const { return m_attrs; }
bool visit_attributes(AttributeVisitor& visitor) override;
const DetectionOutputAttrs& get_attrs() const {
return m_attrs;
}
bool visit_attributes(AttributeVisitor& visitor) override;
private:
DetectionOutputAttrs m_attrs;
};
} // namespace v0
using v0::DetectionOutput;
} // namespace op
} // namespace ngraph
private:
DetectionOutputAttrs m_attrs;
};
} // namespace v0
using v0::DetectionOutput;
} // namespace op
} // namespace ngraph

View File

@ -18,44 +18,38 @@
#include <cstddef>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fft_base.hpp"
namespace ngraph
{
namespace op
{
namespace v7
{
/// \brief An operation DFT that computes the discrete Fourier transformation.
class NGRAPH_API DFT : public util::FFTBase
{
public:
NGRAPH_RTTI_DECLARATION;
DFT() = default;
namespace ngraph {
namespace op {
namespace v7 {
/// \brief An operation DFT that computes the discrete Fourier transformation.
class NGRAPH_API DFT : public util::FFTBase {
public:
NGRAPH_RTTI_DECLARATION;
DFT() = default;
/// \brief Constructs a DFT operation. DFT is performed for full size axes.
///
/// \param data Input data
/// \param axes Axes to perform DFT
DFT(const Output<Node>& data, const Output<Node>& axes);
/// \brief Constructs a DFT operation. DFT is performed for full size axes.
///
/// \param data Input data
/// \param axes Axes to perform DFT
DFT(const Output<Node>& data, const Output<Node>& axes);
/// \brief Constructs a DFT operation.
///
/// \param data Input data
/// \param axes Axes to perform DFT
/// \param signal_size Signal sizes for 'axes'
DFT(const Output<Node>& data,
const Output<Node>& axes,
const Output<Node>& signal_size);
/// \brief Constructs a DFT operation.
///
/// \param data Input data
/// \param axes Axes to perform DFT
/// \param signal_size Signal sizes for 'axes'
DFT(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v7
} // namespace op
} // namespace ngraph
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v7
} // namespace op
} // namespace ngraph

View File

@ -6,57 +6,50 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Elementwise division operation.
class NGRAPH_API Divide : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a division operation.
Divide()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Elementwise division operation.
class NGRAPH_API Divide : public util::BinaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a division operation.
Divide() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) {}
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_pythondiv() const {
return m_pythondiv;
}
void set_is_pythondiv(bool pythondiv) {
m_pythondiv = pythondiv;
}
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
protected:
bool m_pythondiv{true};
};
} // namespace v1
} // namespace op
} // namespace ngraph
protected:
bool m_pythondiv{true};
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -7,69 +7,66 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v7
{
/// \brief Einsum operation.
class NGRAPH_API Einsum : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v7 {
/// \brief Einsum operation.
class NGRAPH_API Einsum : public Op {
public:
NGRAPH_RTTI_DECLARATION;
Einsum() = default;
Einsum() = default;
///
/// \brief Constructs Einsum operation.
///
/// \param inputs Input nodes on which Einsum operation performs
/// contraction
///
/// \param equation Einstein summation convention
///
Einsum(const OutputVector& inputs, const std::string& equation);
///
/// \brief Constructs Einsum operation.
///
/// \param inputs Input nodes on which Einsum operation performs
/// contraction
///
/// \param equation Einstein summation convention
///
Einsum(const OutputVector& inputs, const std::string& equation);
void validate_and_infer_types() override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Get an equation of Einsum operation
///
/// \return Einsum equation
///
std::string get_equation() const { return m_equation; }
/// \brief Get an equation of Einsum operation
///
/// \return Einsum equation
///
std::string get_equation() const {
return m_equation;
}
/// \brief Check correctness of equation format and extract input subscripts
/// and output subscript
///
/// \param equation Equation to be parsed and checked
///
/// \param input_subscripts A vector of extracted input subscripts
///
/// \param output_subscript An output subscript
///
static void parse_equation(const std::string& equation,
std::vector<std::string>& input_subscripts,
std::string& output_subscript);
/// \brief Check correctness of equation format and extract input subscripts
/// and output subscript
///
/// \param equation Equation to be parsed and checked
///
/// \param input_subscripts A vector of extracted input subscripts
///
/// \param output_subscript An output subscript
///
static void parse_equation(const std::string& equation,
std::vector<std::string>& input_subscripts,
std::string& output_subscript);
/// \brief Extract labels (from subscript) that can be alphabetic letters or
/// ellipsis
///
/// \param subscript Subscript
///
/// \return A vector of extracted labels from the input subscript in the order
/// of appearence
///
static std::vector<std::string> extract_labels(const std::string& subscript);
/// \brief Extract labels (from subscript) that can be alphabetic letters or
/// ellipsis
///
/// \param subscript Subscript
///
/// \return A vector of extracted labels from the input subscript in the order
/// of appearence
///
static std::vector<std::string> extract_labels(const std::string& subscript);
private:
std::string m_equation;
};
} // namespace v7
} // namespace op
} // namespace ngraph
private:
std::string m_equation;
};
} // namespace v7
} // namespace op
} // namespace ngraph

View File

@ -7,40 +7,37 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Exponential Linear Unit
/// x < 0 => f(x) = alpha * (exp(x) - 1.)
/// x >= 0 => f(x) = x
///
class NGRAPH_API Elu : public ngraph::op::Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Exponential Linear Unit
/// x < 0 => f(x) = alpha * (exp(x) - 1.)
/// x >= 0 => f(x) = x
///
class NGRAPH_API Elu : public ngraph::op::Op {
public:
NGRAPH_RTTI_DECLARATION;
Elu() = default;
/// \brief Constructs an Elu operation.
///
/// \param data Input tensor
/// \param alpha Multiplier for negative values
Elu(const Output<Node>& data, const double alpha);
Elu() = default;
/// \brief Constructs an Elu operation.
///
/// \param data Input tensor
/// \param alpha Multiplier for negative values
Elu(const Output<Node>& data, const double alpha);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
double get_alpha() const { return m_alpha; }
double get_alpha() const {
return m_alpha;
}
private:
double m_alpha;
};
} // namespace v0
using v0::Elu;
} // namespace op
} // namespace ngraph
private:
double m_alpha;
};
} // namespace v0
using v0::Elu;
} // namespace op
} // namespace ngraph

View File

@ -7,76 +7,75 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/op/util/index_reduction.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingSegmentsSum : public Op
{
public:
static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a EmbeddingSegmentsSum operation.
EmbeddingSegmentsSum() = default;
/// \brief Constructs a EmbeddingSegmentsSum operation.
///
/// EmbeddingSegmentsSum constructs an output tensor by replacing every index in a
/// given
/// input tensor with a row (from the weights matrix) at that index
///
/// \param 'emb_table' tensor containing the embedding lookup table of the module of
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
/// \param 'indices' tensor of shape [num_indices] and of type T_IND. Required
/// \param `segment_ids` tensor of shape `[num_indices]` and of type *T_IND* with
/// indices
/// into the output Tensor. Values should be sorted and can be repeated. Required.
/// \param `num_segments` scalar of type *T_IND* indicating the number of segments.
/// Required.
/// \param 'default_index' scalar of type T_IND containing default index in
/// embedding
/// table to fill empty "bags". If not provided empty "bags"
/// are filled with zeros. Optional.
/// \param 'per_sample_weights' tensor of the same shape as indices and of type T.
/// Each value in this tensor are multiplied with each
/// value pooled from embedding table for each index. Optional.
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingSegmentsSum : public Op {
public:
static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs a EmbeddingSegmentsSum operation.
EmbeddingSegmentsSum() = default;
/// \brief Constructs a EmbeddingSegmentsSum operation.
///
/// EmbeddingSegmentsSum constructs an output tensor by replacing every index in a
/// given
/// input tensor with a row (from the weights matrix) at that index
///
/// \param 'emb_table' tensor containing the embedding lookup table of the module of
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
/// \param 'indices' tensor of shape [num_indices] and of type T_IND. Required
/// \param `segment_ids` tensor of shape `[num_indices]` and of type *T_IND* with
/// indices
/// into the output Tensor. Values should be sorted and can be repeated. Required.
/// \param `num_segments` scalar of type *T_IND* indicating the number of segments.
/// Required.
/// \param 'default_index' scalar of type T_IND containing default index in
/// embedding
/// table to fill empty "bags". If not provided empty "bags"
/// are filled with zeros. Optional.
/// \param 'per_sample_weights' tensor of the same shape as indices and of type T.
/// Each value in this tensor are multiplied with each
/// value pooled from embedding table for each index. Optional.
EmbeddingSegmentsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& segment_ids,
const Output<Node>& num_segments,
const Output<Node>& default_index,
const Output<Node>& per_sample_weights);
EmbeddingSegmentsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& segment_ids,
const Output<Node>& num_segments,
const Output<Node>& default_index,
const Output<Node>& per_sample_weights);
EmbeddingSegmentsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& segment_ids,
const Output<Node>& num_segments,
const Output<Node>& default_index);
EmbeddingSegmentsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& segment_ids,
const Output<Node>& num_segments,
const Output<Node>& default_index);
EmbeddingSegmentsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& segment_ids,
const Output<Node>& num_segments);
EmbeddingSegmentsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& segment_ids,
const Output<Node>& num_segments);
void validate_and_infer_types() override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor&) override { return true; }
bool visit_attributes(AttributeVisitor&) override {
return true;
}
private:
static constexpr int EMB_TABLE = 0;
static constexpr int INDICES = 1;
static constexpr int SEGMENT_IDS = 2;
static constexpr int NUM_SEGMENTS = 3;
static constexpr int DEFAULT_INDEX = 4;
static constexpr int PER_SAMPLE_WEIGHTS = 5;
};
} // namespace v3
using v3::EmbeddingSegmentsSum;
} // namespace op
} // namespace ngraph
private:
static constexpr int EMB_TABLE = 0;
static constexpr int INDICES = 1;
static constexpr int SEGMENT_IDS = 2;
static constexpr int NUM_SEGMENTS = 3;
static constexpr int DEFAULT_INDEX = 4;
static constexpr int PER_SAMPLE_WEIGHTS = 5;
};
} // namespace v3
using v3::EmbeddingSegmentsSum;
} // namespace op
} // namespace ngraph

View File

@ -8,57 +8,52 @@
#include "ngraph/op/util/embeddingbag_offsets_base.hpp"
#include "ngraph/op/util/index_reduction.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingBagOffsetsSum : public util::EmbeddingBagOffsetsBase
{
public:
static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsSum", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
EmbeddingBagOffsetsSum() = default;
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
///
/// EmbeddingBagOffsetsSum constructs an output tensor by replacing every index in a
/// given
/// input tensor with a row (from the weights matrix) at that index
///
/// \param emb_table tensor containing the embedding lookup table of the module of
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
/// \param tensor of shape [num_indices] and of type T_IND. Required
/// \param offsets tensor of shape [batch] and of type T_IND containing the starting
/// index positions of each "bag" in indices. Required.
/// \param default_index scalar of type T_IND containing default index in embedding
/// table to fill empty "bags". If not provided empty "bags"
/// are filled with zeros. Optional.
/// \param per_sample_weigths tensor of the same shape as indices and of type T.
/// Each value in this tensor are multiplied with each
/// value pooled from embedding table for each index. Optional.
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingBagOffsetsSum : public util::EmbeddingBagOffsetsBase {
public:
static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsSum", 3};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
EmbeddingBagOffsetsSum() = default;
/// \brief Constructs a EmbeddingBagOffsetsSum operation.
///
/// EmbeddingBagOffsetsSum constructs an output tensor by replacing every index in a
/// given
/// input tensor with a row (from the weights matrix) at that index
///
/// \param emb_table tensor containing the embedding lookup table of the module of
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
/// \param tensor of shape [num_indices] and of type T_IND. Required
/// \param offsets tensor of shape [batch] and of type T_IND containing the starting
/// index positions of each "bag" in indices. Required.
/// \param default_index scalar of type T_IND containing default index in embedding
/// table to fill empty "bags". If not provided empty "bags"
/// are filled with zeros. Optional.
/// \param per_sample_weigths tensor of the same shape as indices and of type T.
/// Each value in this tensor are multiplied with each
/// value pooled from embedding table for each index. Optional.
EmbeddingBagOffsetsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& offsets,
const Output<Node>& default_index,
const Output<Node>& per_sample_weights);
EmbeddingBagOffsetsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& offsets,
const Output<Node>& default_index,
const Output<Node>& per_sample_weights);
EmbeddingBagOffsetsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& offsets,
const Output<Node>& default_index);
EmbeddingBagOffsetsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& offsets,
const Output<Node>& default_index);
EmbeddingBagOffsetsSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& offsets);
EmbeddingBagOffsetsSum(const Output<Node>& emb_table, const Output<Node>& indices, const Output<Node>& offsets);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v3
using v3::EmbeddingBagOffsetsSum;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v3
using v3::EmbeddingBagOffsetsSum;
} // namespace op
} // namespace ngraph

View File

@ -8,44 +8,41 @@
#include "ngraph/op/util/embeddingbag_packed_base.hpp"
#include "ngraph/op/util/index_reduction.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingBagPackedSum : public util::EmbeddingBagPackedBase
{
public:
static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedSum", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a EmbeddingBagPackedSum operation.
EmbeddingBagPackedSum() = default;
/// \brief Constructs a EmbeddingBagPackedSum operation.
///
/// EmbeddingBagPackedSum constructs an output tensor by replacing every index in a
/// given
/// input tensor with a row (from the weights matrix) at that index
///
/// \param emb_table Tensor containing the embedding lookup table of the module of
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
/// \param indices Tensor of shape `[batch, indices_per_bag]` and of type *T_IND*.
/// Required.
/// \param per_sample_weigths tensor of the same shape as indices and of type T.
/// Each value in this tensor are multiplied with each
/// value pooled from embedding table for each index. Optional.
namespace ngraph {
namespace op {
namespace v3 {
/// \brief Returns embeddings for given indices
class NGRAPH_API EmbeddingBagPackedSum : public util::EmbeddingBagPackedBase {
public:
static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedSum", 3};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
/// \brief Constructs a EmbeddingBagPackedSum operation.
EmbeddingBagPackedSum() = default;
/// \brief Constructs a EmbeddingBagPackedSum operation.
///
/// EmbeddingBagPackedSum constructs an output tensor by replacing every index in a
/// given
/// input tensor with a row (from the weights matrix) at that index
///
/// \param emb_table Tensor containing the embedding lookup table of the module of
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
/// \param indices Tensor of shape `[batch, indices_per_bag]` and of type *T_IND*.
/// Required.
/// \param per_sample_weigths tensor of the same shape as indices and of type T.
/// Each value in this tensor are multiplied with each
/// value pooled from embedding table for each index. Optional.
EmbeddingBagPackedSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& per_sample_weights);
EmbeddingBagPackedSum(const Output<Node>& emb_table,
const Output<Node>& indices,
const Output<Node>& per_sample_weights);
EmbeddingBagPackedSum(const Output<Node>& emb_table, const Output<Node>& indices);
EmbeddingBagPackedSum(const Output<Node>& emb_table, const Output<Node>& indices);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v3
using v3::EmbeddingBagPackedSum;
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v3
using v3::EmbeddingBagPackedSum;
} // namespace op
} // namespace ngraph

View File

@ -6,13 +6,10 @@
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
// clang-format off
namespace ngraph {
namespace op {
namespace v1 {
// clang-format off
/// \brief Elementwise is-equal operation.
///
/// ## Inputs
@ -28,34 +25,27 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class NGRAPH_API Equal : public util::BinaryElementwiseComparison
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an equal operation.
Equal()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
// clang-format on
class NGRAPH_API Equal : public util::BinaryElementwiseComparison {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an equal operation.
Equal() : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY) {}
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph

View File

@ -6,32 +6,26 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise erf operation.
class NGRAPH_API Erf : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a floor operation.
Erf() = default;
/// \brief Constructs a floor operation.
///
/// \param arg Node that produces the input tensor.
Erf(const Output<Node>& arg);
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise erf operation.
class NGRAPH_API Erf : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a floor operation.
Erf() = default;
/// \brief Constructs a floor operation.
///
/// \param arg Node that produces the input tensor.
Erf(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Erf;
} // namespace op
} // namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Erf;
} // namespace op
} // namespace ngraph

View File

@ -6,34 +6,28 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise natural exponential (exp) operation.
class NGRAPH_API Exp : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise natural exponential (exp) operation.
class NGRAPH_API Exp : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an exponential operation.
Exp() = default;
/// \brief Constructs an exponential operation.
///
/// \param arg Node that produces the input tensor.
Exp(const Output<Node>& arg);
/// \brief Constructs an exponential operation.
Exp() = default;
/// \brief Constructs an exponential operation.
///
/// \param arg Node that produces the input tensor.
Exp(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Exp;
} // namespace op
} // namespace ngraph
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Exp;
} // namespace op
} // namespace ngraph

View File

@ -6,72 +6,69 @@
#include <cstddef>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief An operation ExperimentalDetectronDetectionOutput performs
/// non-maximum suppression to generate the detection output using
/// information on location and score predictions.
class NGRAPH_API ExperimentalDetectronDetectionOutput : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v6 {
/// \brief An operation ExperimentalDetectronDetectionOutput performs
/// non-maximum suppression to generate the detection output using
/// information on location and score predictions.
class NGRAPH_API ExperimentalDetectronDetectionOutput : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation
struct Attributes
{
// specifies score threshold
float score_threshold;
// specifies NMS threshold
float nms_threshold;
// specifies maximal delta of logarithms for width and height
float max_delta_log_wh;
// specifies number of detected classes
int64_t num_classes;
// specifies maximal number of detections per class
int64_t post_nms_count;
// specifies maximual number of detections per image
size_t max_detections_per_image;
// a flag specifies whether to delete background classes or not
// `true` means background classes should be deleted,
// `false` means background classes shouldn't be deleted.
bool class_agnostic_box_regression;
// specifies deltas of weights
std::vector<float> deltas_weights;
};
/// \brief Structure that specifies attributes of the operation
struct Attributes {
// specifies score threshold
float score_threshold;
// specifies NMS threshold
float nms_threshold;
// specifies maximal delta of logarithms for width and height
float max_delta_log_wh;
// specifies number of detected classes
int64_t num_classes;
// specifies maximal number of detections per class
int64_t post_nms_count;
// specifies maximual number of detections per image
size_t max_detections_per_image;
// a flag specifies whether to delete background classes or not
// `true` means background classes should be deleted,
// `false` means background classes shouldn't be deleted.
bool class_agnostic_box_regression;
// specifies deltas of weights
std::vector<float> deltas_weights;
};
ExperimentalDetectronDetectionOutput() = default;
/// \brief Constructs a ExperimentalDetectronDetectionOutput operation.
///
/// \param input_rois Input rois
/// \param input_deltas Input deltas
/// \param input_scores Input scores
/// \param input_im_info Input image info
/// \param attrs Attributes attributes
ExperimentalDetectronDetectionOutput(const Output<Node>& input_rois,
const Output<Node>& input_deltas,
const Output<Node>& input_scores,
const Output<Node>& input_im_info,
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
ExperimentalDetectronDetectionOutput() = default;
/// \brief Constructs a ExperimentalDetectronDetectionOutput operation.
///
/// \param input_rois Input rois
/// \param input_deltas Input deltas
/// \param input_scores Input scores
/// \param input_im_info Input image info
/// \param attrs Attributes attributes
ExperimentalDetectronDetectionOutput(const Output<Node>& input_rois,
const Output<Node>& input_deltas,
const Output<Node>& input_scores,
const Output<Node>& input_im_info,
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of the operation ExperimentalDetectronDetectionOutput
const Attributes& get_attrs() const { return m_attrs; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of the operation ExperimentalDetectronDetectionOutput
const Attributes& get_attrs() const {
return m_attrs;
}
private:
Attributes m_attrs;
};
} // namespace v6
} // namespace op
} // namespace ngraph
private:
Attributes m_attrs;
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -6,62 +6,59 @@
#include <cstdint>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief An operation ExperimentalDetectronGenerateProposalsSingleImage
/// computes ROIs and their scores based on input data.
class NGRAPH_API ExperimentalDetectronGenerateProposalsSingleImage : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v6 {
/// \brief An operation ExperimentalDetectronGenerateProposalsSingleImage
/// computes ROIs and their scores based on input data.
class NGRAPH_API ExperimentalDetectronGenerateProposalsSingleImage : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation
struct Attributes
{
// minimum box width & height
float min_size;
// specifies NMS threshold
float nms_threshold;
// number of top-n proposals after NMS
int64_t post_nms_count;
// number of top-n proposals before NMS
int64_t pre_nms_count;
};
/// \brief Structure that specifies attributes of the operation
struct Attributes {
// minimum box width & height
float min_size;
// specifies NMS threshold
float nms_threshold;
// number of top-n proposals after NMS
int64_t post_nms_count;
// number of top-n proposals before NMS
int64_t pre_nms_count;
};
ExperimentalDetectronGenerateProposalsSingleImage() = default;
/// \brief Constructs a ExperimentalDetectronGenerateProposalsSingleImage operation.
///
/// \param im_info Input image info
/// \param anchors Input anchors
/// \param deltas Input deltas
/// \param scores Input scores
/// \param attrs Operation attributes
ExperimentalDetectronGenerateProposalsSingleImage(const Output<Node>& im_info,
const Output<Node>& anchors,
const Output<Node>& deltas,
const Output<Node>& scores,
const Attributes& attrs);
ExperimentalDetectronGenerateProposalsSingleImage() = default;
/// \brief Constructs a ExperimentalDetectronGenerateProposalsSingleImage operation.
///
/// \param im_info Input image info
/// \param anchors Input anchors
/// \param deltas Input deltas
/// \param scores Input scores
/// \param attrs Operation attributes
ExperimentalDetectronGenerateProposalsSingleImage(const Output<Node>& im_info,
const Output<Node>& anchors,
const Output<Node>& deltas,
const Output<Node>& scores,
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const Attributes& get_attrs() const { return m_attrs; }
const Attributes& get_attrs() const {
return m_attrs;
}
private:
Attributes m_attrs;
};
} // namespace v6
} // namespace op
} // namespace ngraph
private:
Attributes m_attrs;
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -6,65 +6,62 @@
#include <cstdint>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief An operation ExperimentalDetectronPriorGridGenerator generates prior
/// grids of specified sizes.
class NGRAPH_API ExperimentalDetectronPriorGridGenerator : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v6 {
/// \brief An operation ExperimentalDetectronPriorGridGenerator generates prior
/// grids of specified sizes.
class NGRAPH_API ExperimentalDetectronPriorGridGenerator : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation
struct Attributes
{
// Specifies whether the output tensor should be 2D or 4D
// `true` means the output tensor should be 2D tensor,
// `false` means the output tensor should be 4D tensor.
bool flatten;
// Specifies number of cells of the generated grid with respect to height.
int64_t h;
// Specifies number of cells of the generated grid with respect to width.
int64_t w;
// Specifies the step of generated grid with respect to x coordinate
float stride_x;
// Specifies the step of generated grid with respect to y coordinate
float stride_y;
};
/// \brief Structure that specifies attributes of the operation
struct Attributes {
// Specifies whether the output tensor should be 2D or 4D
// `true` means the output tensor should be 2D tensor,
// `false` means the output tensor should be 4D tensor.
bool flatten;
// Specifies number of cells of the generated grid with respect to height.
int64_t h;
// Specifies number of cells of the generated grid with respect to width.
int64_t w;
// Specifies the step of generated grid with respect to x coordinate
float stride_x;
// Specifies the step of generated grid with respect to y coordinate
float stride_y;
};
ExperimentalDetectronPriorGridGenerator() = default;
/// \brief Constructs a ExperimentalDetectronDetectionOutput operation.
///
/// \param priors Input priors
/// \param feature_map Input feature map
/// \param im_data Image data
/// \param attrs attributes
ExperimentalDetectronPriorGridGenerator(const Output<Node>& priors,
const Output<Node>& feature_map,
const Output<Node>& im_data,
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
ExperimentalDetectronPriorGridGenerator() = default;
/// \brief Constructs a ExperimentalDetectronDetectionOutput operation.
///
/// \param priors Input priors
/// \param feature_map Input feature map
/// \param im_data Image data
/// \param attrs attributes
ExperimentalDetectronPriorGridGenerator(const Output<Node>& priors,
const Output<Node>& feature_map,
const Output<Node>& im_data,
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of this operation.
const Attributes& get_attrs() const { return m_attrs; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of this operation.
const Attributes& get_attrs() const {
return m_attrs;
}
private:
Attributes m_attrs;
private:
Attributes m_attrs;
void validate();
};
} // namespace v6
} // namespace op
} // namespace ngraph
void validate();
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -7,58 +7,53 @@
#include <cstddef>
#include <cstdint>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief An operation ExperimentalDetectronROIFeatureExtractor
/// is the ROIAlign operation applied over a feature pyramid.
class NGRAPH_API ExperimentalDetectronROIFeatureExtractor : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v6 {
/// \brief An operation ExperimentalDetectronROIFeatureExtractor
/// is the ROIAlign operation applied over a feature pyramid.
class NGRAPH_API ExperimentalDetectronROIFeatureExtractor : public Op {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Structure that specifies attributes of the operation
struct Attributes
{
int64_t output_size;
int64_t sampling_ratio;
std::vector<int64_t> pyramid_scales;
bool aligned;
};
/// \brief Structure that specifies attributes of the operation
struct Attributes {
int64_t output_size;
int64_t sampling_ratio;
std::vector<int64_t> pyramid_scales;
bool aligned;
};
ExperimentalDetectronROIFeatureExtractor() = default;
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
///
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
/// \param attrs Operation attributes
ExperimentalDetectronROIFeatureExtractor(const OutputVector& args,
const Attributes& attrs);
ExperimentalDetectronROIFeatureExtractor() = default;
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
///
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
/// \param attrs Operation attributes
ExperimentalDetectronROIFeatureExtractor(const OutputVector& args, const Attributes& attrs);
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
///
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
/// \param attrs Operation attributes
ExperimentalDetectronROIFeatureExtractor(const NodeVector& args,
const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Constructs a ExperimentalDetectronROIFeatureExtractor operation.
///
/// \param args Inputs of ExperimentalDetectronROIFeatureExtractor
/// \param attrs Operation attributes
ExperimentalDetectronROIFeatureExtractor(const NodeVector& args, const Attributes& attrs);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of the operation.
const Attributes& get_attrs() const { return m_attrs; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Returns attributes of the operation.
const Attributes& get_attrs() const {
return m_attrs;
}
private:
Attributes m_attrs;
};
} // namespace v6
} // namespace op
} // namespace ngraph
private:
Attributes m_attrs;
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -6,44 +6,40 @@
#include <cstdint>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief An operation ExperimentalDetectronTopKROIs, according to the repository
/// is TopK operation applied to probabilities of input ROIs.
class NGRAPH_API ExperimentalDetectronTopKROIs : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v6 {
/// \brief An operation ExperimentalDetectronTopKROIs, according to the repository
/// is TopK operation applied to probabilities of input ROIs.
class NGRAPH_API ExperimentalDetectronTopKROIs : public Op {
public:
NGRAPH_RTTI_DECLARATION;
ExperimentalDetectronTopKROIs() = default;
/// \brief Constructs a ExperimentalDetectronTopKROIs operation.
///
/// \param input_rois Input rois
/// \param rois_probs Probabilities for input rois
/// \param max_rois Maximal numbers of output rois
ExperimentalDetectronTopKROIs(const Output<Node>& input_rois,
const Output<Node>& rois_probs,
size_t max_rois = 0);
ExperimentalDetectronTopKROIs() = default;
/// \brief Constructs a ExperimentalDetectronTopKROIs operation.
///
/// \param input_rois Input rois
/// \param rois_probs Probabilities for input rois
/// \param max_rois Maximal numbers of output rois
ExperimentalDetectronTopKROIs(const Output<Node>& input_rois, const Output<Node>& rois_probs, size_t max_rois = 0);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_max_rois() const { return m_max_rois; }
size_t get_max_rois() const {
return m_max_rois;
}
private:
size_t m_max_rois;
};
} // namespace v6
} // namespace op
} // namespace ngraph
private:
size_t m_max_rois;
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -6,55 +6,66 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
class NGRAPH_API ExtractImagePatches : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v3 {
class NGRAPH_API ExtractImagePatches : public Op {
public:
NGRAPH_RTTI_DECLARATION;
ExtractImagePatches() = default;
/// \brief Constructs a ExtractImagePatches operation
///
/// \param data 4-D Input data to extract image patches
/// \param sizes Patch size in the format of [size_rows, size_cols]
/// \param strides Patch movement stride in the format of [stride_rows, stride_cols]
/// \param rates Element selection rate for creating a patch, in the format of
/// [rate_rows, rate_cols]
/// \param auto_pad Padding type. it can be any value from
/// valid, same_lower, same_upper
ExtractImagePatches(const Output<Node>& image,
const Shape& sizes,
const Strides& strides,
const Shape& rates,
const PadType& auto_pad);
ExtractImagePatches() = default;
/// \brief Constructs a ExtractImagePatches operation
///
/// \param data 4-D Input data to extract image patches
/// \param sizes Patch size in the format of [size_rows, size_cols]
/// \param strides Patch movement stride in the format of [stride_rows, stride_cols]
/// \param rates Element selection rate for creating a patch, in the format of
/// [rate_rows, rate_cols]
/// \param auto_pad Padding type. it can be any value from
/// valid, same_lower, same_upper
ExtractImagePatches(const Output<Node>& image,
const Shape& sizes,
const Strides& strides,
const Shape& rates,
const PadType& auto_pad);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const Shape& get_sizes() const { return m_patch_sizes; }
void set_sizes(const Shape& sizes) { m_patch_sizes = sizes; }
const Strides& get_strides() const { return m_patch_movement_strides; }
void set_strides(const Strides& strides) { m_patch_movement_strides = strides; }
const Shape& get_rates() const { return m_patch_selection_rates; }
void set_rates(const Shape& rates) { m_patch_selection_rates = rates; }
const PadType& get_auto_pad() const { return m_padding; }
void set_auto_pad(PadType& padding) { m_padding = padding; }
const Shape& get_sizes() const {
return m_patch_sizes;
}
void set_sizes(const Shape& sizes) {
m_patch_sizes = sizes;
}
const Strides& get_strides() const {
return m_patch_movement_strides;
}
void set_strides(const Strides& strides) {
m_patch_movement_strides = strides;
}
const Shape& get_rates() const {
return m_patch_selection_rates;
}
void set_rates(const Shape& rates) {
m_patch_selection_rates = rates;
}
const PadType& get_auto_pad() const {
return m_padding;
}
void set_auto_pad(PadType& padding) {
m_padding = padding;
}
private:
Shape m_patch_sizes;
Strides m_patch_movement_strides;
Shape m_patch_selection_rates;
PadType m_padding;
};
} // namespace v3
using v3::ExtractImagePatches;
} // namespace op
} // namespace ngraph
private:
Shape m_patch_sizes;
Strides m_patch_movement_strides;
Shape m_patch_selection_rates;
PadType m_padding;
};
} // namespace v3
using v3::ExtractImagePatches;
} // namespace op
} // namespace ngraph

View File

@ -8,71 +8,70 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
///
/// \brief Class performing element-wise linear quantization.
///
/// \note Input floating point values are quantized into a discrete
/// set of floating point values.
///
/// \paragraph Implementation This class creates a node which performs the following
/// operation:
///
/// round((data - input_low) / (input_high - input_low) * (levels-1)) /
/// (levels-1) * (output_high - output_low) + output_low
///
///
class NGRAPH_API FakeQuantize : public ngraph::op::Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v0 {
///
/// \brief Class performing element-wise linear quantization.
///
/// \note Input floating point values are quantized into a discrete
/// set of floating point values.
///
/// \paragraph Implementation This class creates a node which performs the following
/// operation:
///
/// round((data - input_low) / (input_high - input_low) * (levels-1)) /
/// (levels-1) * (output_high - output_low) + output_low
///
///
class NGRAPH_API FakeQuantize : public ngraph::op::Op {
public:
NGRAPH_RTTI_DECLARATION;
FakeQuantize();
///
/// \brief Constructs a FakeQuantize operation node.
///
/// \param[in] data The input data tensor.
/// \param[in] input_low The minimum limit for input values.
/// \param[in] input_high The maximum limit for input values.
/// \param[in] output_low The minimum quantized value.
/// \param[in] output_high The maximum quantized value.
/// \param[in] levels The number of quantization levels.
/// \param[in] auto_broadcast AutoBroadcast mode to be used for broadcasting
/// limit values
///
FakeQuantize(const Output<Node>& data,
const Output<Node>& input_low,
const Output<Node>& input_high,
const Output<Node>& output_low,
const Output<Node>& output_high,
std::size_t levels,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
FakeQuantize();
///
/// \brief Constructs a FakeQuantize operation node.
///
/// \param[in] data The input data tensor.
/// \param[in] input_low The minimum limit for input values.
/// \param[in] input_high The maximum limit for input values.
/// \param[in] output_low The minimum quantized value.
/// \param[in] output_high The maximum quantized value.
/// \param[in] levels The number of quantization levels.
/// \param[in] auto_broadcast AutoBroadcast mode to be used for broadcasting
/// limit values
///
FakeQuantize(const Output<Node>& data,
const Output<Node>& input_low,
const Output<Node>& input_high,
const Output<Node>& output_low,
const Output<Node>& output_high,
std::size_t levels,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
std::size_t get_levels() const { return m_levels; }
void set_levels(std::size_t levels) { m_levels = levels; }
const AutoBroadcastSpec& get_auto_broadcast() const { return m_auto_broadcast; }
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast)
{
m_auto_broadcast = auto_broadcast;
}
std::size_t get_levels() const {
return m_levels;
}
void set_levels(std::size_t levels) {
m_levels = levels;
}
const AutoBroadcastSpec& get_auto_broadcast() const {
return m_auto_broadcast;
}
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast) {
m_auto_broadcast = auto_broadcast;
}
private:
std::size_t m_levels;
AutoBroadcastSpec m_auto_broadcast = op::AutoBroadcastType::NUMPY;
};
} // namespace v0
using v0::FakeQuantize;
} // namespace op
} // namespace ngraph
private:
std::size_t m_levels;
AutoBroadcastSpec m_auto_broadcast = op::AutoBroadcastType::NUMPY;
};
} // namespace v0
using v0::FakeQuantize;
} // namespace op
} // namespace ngraph

View File

@ -6,32 +6,26 @@
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Elementwise floor operation.
class NGRAPH_API Floor : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a floor operation.
Floor() = default;
/// \brief Constructs a floor operation.
///
/// \param arg Node that produces the input tensor.
Floor(const Output<Node>& arg);
namespace ngraph {
namespace op {
namespace v0 {
/// \brief Elementwise floor operation.
class NGRAPH_API Floor : public util::UnaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a floor operation.
Floor() = default;
/// \brief Constructs a floor operation.
///
/// \param arg Node that produces the input tensor.
Floor(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Floor;
} // namespace op
} // namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
using v0::Floor;
} // namespace op
} // namespace ngraph

View File

@ -8,46 +8,39 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Elementwise FloorMod operation.
///
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Elementwise FloorMod operation.
///
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic {
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an uninitialized FloorMod operation
FloorMod()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
/// \brief Constructs an uninitialized FloorMod operation
FloorMod() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
/// \brief Constructs a FloorMod operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
FloorMod(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
/// \brief Constructs a FloorMod operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
FloorMod(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1
using v1::FloorMod;
} // namespace op
} // namespace ngraph
using v1::FloorMod;
} // namespace op
} // namespace ngraph

View File

@ -6,87 +6,74 @@
#include "ngraph/op/util/gather_base.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase
{
public:
NGRAPH_RTTI_DECLARATION;
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
Gather(const Output<Node>& params,
const Output<Node>& indices,
const Output<Node>& axis);
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase {
public:
NGRAPH_RTTI_DECLARATION;
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axis);
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t get_axis() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t get_axis() const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v1
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v1
namespace v7
{
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase
{
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;
namespace v7 {
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase {
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
/// \param batch_dims The number of batch dimension in data and indices tensors.
/// If batch_dims = 0 Gather v7 is identical to Gather v1.
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
/// \param batch_dims The number of batch dimension in data and indices tensors.
/// If batch_dims = 0 Gather v7 is identical to Gather v1.
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
int64_t get_batch_dims() const;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
int64_t get_batch_dims() const;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v7
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v7
namespace v8
{
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase
{
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;
namespace v8 {
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase {
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
/// \param batch_dims The number of batch dimension in data and indices tensors.
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
/// \param batch_dims The number of batch dimension in data and indices tensors.
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
int64_t get_batch_dims() const;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
int64_t get_batch_dims() const;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v8
} // namespace op
} // namespace ngraph
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v8
} // namespace op
} // namespace ngraph

View File

@ -6,39 +6,34 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief GatherElements operation
///
class NGRAPH_API GatherElements : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
GatherElements() = default;
namespace ngraph {
namespace op {
namespace v6 {
/// \brief GatherElements operation
///
class NGRAPH_API GatherElements : public Op {
public:
NGRAPH_RTTI_DECLARATION;
GatherElements() = default;
/// \brief Constructs a GatherElements operation.
///
/// \param data Node producing data that are gathered
/// \param indices Node producing indices by which the operation gathers elements
/// \param axis specifies axis along which indices are specified
GatherElements(const Output<Node>& data,
const Output<Node>& indices,
const int64_t axis);
/// \brief Constructs a GatherElements operation.
///
/// \param data Node producing data that are gathered
/// \param indices Node producing indices by which the operation gathers elements
/// \param axis specifies axis along which indices are specified
GatherElements(const Output<Node>& data, const Output<Node>& indices, const int64_t axis);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_axis() const { return m_axis; }
int64_t get_axis() const {
return m_axis;
}
private:
int64_t m_axis;
};
} // namespace v6
} // namespace op
} // namespace ngraph
private:
int64_t m_axis;
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -6,40 +6,35 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v5
{
/// \brief GatherND operation
///
class NGRAPH_API GatherND : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
GatherND() = default;
namespace ngraph {
namespace op {
namespace v5 {
/// \brief GatherND operation
///
class NGRAPH_API GatherND : public Op {
public:
NGRAPH_RTTI_DECLARATION;
GatherND() = default;
/// \brief Constructs a GatherND operation.
///
/// \param data Node producing data that are gathered
/// \param indices Node producing indices by which the operation gathers elements
/// or slices from data
/// \param batch_dims Specifies a number of batch dimensions
GatherND(const Output<Node>& data,
const Output<Node>& indices,
const size_t batch_dims = 0);
/// \brief Constructs a GatherND operation.
///
/// \param data Node producing data that are gathered
/// \param indices Node producing indices by which the operation gathers elements
/// or slices from data
/// \param batch_dims Specifies a number of batch dimensions
GatherND(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims = 0);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_batch_dims() const { return m_batch_dims; }
size_t get_batch_dims() const {
return m_batch_dims;
}
private:
size_t m_batch_dims;
};
} // namespace v5
} // namespace op
} // namespace ngraph
private:
size_t m_batch_dims;
};
} // namespace v5
} // namespace op
} // namespace ngraph

View File

@ -6,38 +6,33 @@
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Generates the complete beams from the ids per each step and the parent beam
/// ids.
class NGRAPH_API GatherTree : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Generates the complete beams from the ids per each step and the parent beam
/// ids.
class NGRAPH_API GatherTree : public Op {
public:
NGRAPH_RTTI_DECLARATION;
GatherTree() = default;
/// \param step_ids Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with
/// indices from per each step
/// \param parent_idx Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with
/// parent beam indices
/// \param max_seq_len Tensor of shape [BATCH_SIZE] with maximum lengths for each
/// sequence in the batch
/// \param end_token Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
GatherTree(const Output<Node>& step_ids,
const Output<Node>& parent_idx,
const Output<Node>& max_seq_len,
const Output<Node>& end_token);
GatherTree() = default;
/// \param step_ids Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with
/// indices from per each step
/// \param parent_idx Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with
/// parent beam indices
/// \param max_seq_len Tensor of shape [BATCH_SIZE] with maximum lengths for each
/// sequence in the batch
/// \param end_token Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
GatherTree(const Output<Node>& step_ids,
const Output<Node>& parent_idx,
const Output<Node>& max_seq_len,
const Output<Node>& end_token);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v1
} // namespace op
} // namespace ngraph

Some files were not shown because too many files have changed in this diff Show More